From c3b4efc89e925f5ca149e532fab7149a46a608a5 Mon Sep 17 00:00:00 2001 From: George Date: Mon, 5 Jun 2023 17:25:51 +0800 Subject: [PATCH 001/180] updated CMakeLists.txt and added JNI implementation to support building this library as a dependency in Android Studio with NDK --- CMakeLists.txt | 271 +++++++++++++++++++++++++++---------------------- llama-jni.cpp | 12 +++ 2 files changed, 162 insertions(+), 121 deletions(-) create mode 100644 llama-jni.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 1f2e78c0ffba4..e06046c40cd19 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,7 @@ project("llama.cpp" C CXX) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) +if(NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") endif() @@ -13,69 +13,69 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) set(LLAMA_STANDALONE ON) - # configure project version - # TODO +# configure project version +# TODO else() set(LLAMA_STANDALONE OFF) endif() -if (EMSCRIPTEN) +if(EMSCRIPTEN) set(BUILD_SHARED_LIBS_DEFAULT OFF) option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON) else() - if (MINGW) + if(MINGW) set(BUILD_SHARED_LIBS_DEFAULT OFF) else() set(BUILD_SHARED_LIBS_DEFAULT ON) endif() endif() - # # Option list # # general -option(LLAMA_STATIC "llama: static link libraries" OFF) -option(LLAMA_NATIVE "llama: enable -march=native flag" OFF) -option(LLAMA_LTO "llama: enable link time optimization" OFF) +option(LLAMA_STATIC "llama: static link libraries" OFF) +option(LLAMA_NATIVE "llama: enable -march=native flag" OFF) +option(LLAMA_LTO "llama: enable link time optimization" OFF) # debug -option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON) -option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF) -option(LLAMA_GPROF "llama: enable gprof" OFF) +option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON) +option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF) +option(LLAMA_GPROF "llama: enable gprof" OFF) # sanitizers -option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF) -option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF) -option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF) +option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF) +option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF) +option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF) # instruction set specific -option(LLAMA_AVX "llama: enable AVX" ON) -option(LLAMA_AVX2 "llama: enable AVX2" ON) -option(LLAMA_AVX512 "llama: enable AVX512" OFF) -option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF) -option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF) -option(LLAMA_FMA "llama: enable FMA" ON) +option(LLAMA_AVX "llama: enable AVX" ON) +option(LLAMA_AVX2 "llama: enable AVX2" ON) +option(LLAMA_AVX512 "llama: enable AVX512" OFF) +option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF) +option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF) +option(LLAMA_FMA "llama: enable FMA" ON) + # in MSVC F16C is implied with AVX2/AVX512 -if (NOT MSVC) - option(LLAMA_F16C "llama: enable F16C" ON) +if(NOT MSVC) + option(LLAMA_F16C "llama: enable F16C" ON) endif() # 3rd party libs -option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) -option(LLAMA_BLAS "llama: use BLAS" OFF) +option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) +option(LLAMA_BLAS "llama: use BLAS" OFF) set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") -option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) -set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") -set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels") -option(LLAMA_CLBLAST "llama: use CLBlast" OFF) -option(LLAMA_METAL "llama: use Metal" OFF) +option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) +set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") +set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels") +option(LLAMA_CLBLAST "llama: use CLBlast" OFF) +option(LLAMA_METAL "llama: use Metal" OFF) -option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) -option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) -option(LLAMA_BUILD_SERVER "llama: build server example" OFF) +option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) +option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) +option(LLAMA_BUILD_SERVER "llama: build server example" OFF) # # Build info header @@ -113,7 +113,6 @@ endif() # # Compile flags # - set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD_REQUIRED true) set(CMAKE_C_STANDARD 11) @@ -121,26 +120,27 @@ set(CMAKE_C_STANDARD_REQUIRED true) set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) -if (NOT MSVC) - if (LLAMA_SANITIZE_THREAD) +if(NOT MSVC) + if(LLAMA_SANITIZE_THREAD) add_compile_options(-fsanitize=thread) link_libraries(-fsanitize=thread) endif() - if (LLAMA_SANITIZE_ADDRESS) + if(LLAMA_SANITIZE_ADDRESS) add_compile_options(-fsanitize=address -fno-omit-frame-pointer) link_libraries(-fsanitize=address) endif() - if (LLAMA_SANITIZE_UNDEFINED) + if(LLAMA_SANITIZE_UNDEFINED) add_compile_options(-fsanitize=undefined) link_libraries(-fsanitize=undefined) endif() endif() -if (APPLE AND LLAMA_ACCELERATE) +if(APPLE AND LLAMA_ACCELERATE) find_library(ACCELERATE_FRAMEWORK Accelerate) - if (ACCELERATE_FRAMEWORK) + + if(ACCELERATE_FRAMEWORK) message(STATUS "Accelerate framework found") add_compile_definitions(GGML_USE_ACCELERATE) @@ -150,16 +150,19 @@ if (APPLE AND LLAMA_ACCELERATE) endif() endif() -if (LLAMA_BLAS) - if (LLAMA_STATIC) +if(LLAMA_BLAS) + if(LLAMA_STATIC) set(BLA_STATIC ON) endif() - if ($(CMAKE_VERSION) VERSION_GREATER_EQUAL 3.22) + + if($(CMAKE_VERSION) VERSION_GREATER_EQUAL 3.22) set(BLA_SIZEOF_INTEGER 8) endif() + set(BLA_VENDOR ${LLAMA_BLAS_VENDOR}) find_package(BLAS) - if (BLAS_FOUND) + + if(BLAS_FOUND) message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") add_compile_options(${BLAS_LINKER_FLAGS}) @@ -170,16 +173,17 @@ if (LLAMA_BLAS) include_directories(${BLAS_INCLUDE_DIRS}) else() message(WARNING "BLAS not found, please refer to " - "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" - " to set correct LLAMA_BLAS_VENDOR") + "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" + " to set correct LLAMA_BLAS_VENDOR") endif() endif() -if (LLAMA_CUBLAS) +if(LLAMA_CUBLAS) cmake_minimum_required(VERSION 3.17) find_package(CUDAToolkit) - if (CUDAToolkit_FOUND) + + if(CUDAToolkit_FOUND) message(STATUS "cuBLAS found") enable_language(CUDA) @@ -190,7 +194,7 @@ if (LLAMA_CUBLAS) add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y}) - if (LLAMA_STATIC) + if(LLAMA_STATIC) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) else() set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) @@ -201,10 +205,10 @@ if (LLAMA_CUBLAS) endif() endif() -if (LLAMA_METAL) - find_library(FOUNDATION_LIBRARY Foundation REQUIRED) - find_library(METAL_FRAMEWORK Metal REQUIRED) - find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) +if(LLAMA_METAL) + find_library(FOUNDATION_LIBRARY Foundation REQUIRED) + find_library(METAL_FRAMEWORK Metal REQUIRED) + find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED) set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h) @@ -213,7 +217,7 @@ if (LLAMA_METAL) add_compile_definitions(GGML_METAL_NDEBUG) # get full path to the file - #add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/") + # add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/") # copy ggml-metal.metal to bin directory configure_file(ggml-metal.metal bin/ggml-metal.metal COPYONLY) @@ -223,12 +227,13 @@ if (LLAMA_METAL) ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK} ${METALPERFORMANCE_FRAMEWORK} - ) + ) endif() -if (LLAMA_CLBLAST) +if(LLAMA_CLBLAST) find_package(CLBlast) - if (CLBlast_FOUND) + + if(CLBlast_FOUND) message(STATUS "CLBlast found") set(GGML_SOURCES_OPENCL ggml-opencl.cpp ggml-opencl.h) @@ -241,8 +246,8 @@ if (LLAMA_CLBLAST) endif() endif() -if (LLAMA_ALL_WARNINGS) - if (NOT MSVC) +if(LLAMA_ALL_WARNINGS) + if(NOT MSVC) set(c_flags -Wall -Wextra @@ -266,24 +271,24 @@ if (LLAMA_ALL_WARNINGS) endif() add_compile_options( - "$<$:${c_flags}>" - "$<$:${cxx_flags}>" + "$<$:${c_flags}>" + "$<$:${cxx_flags}>" ) - endif() -if (MSVC) +if(MSVC) add_compile_definitions(_CRT_SECURE_NO_WARNINGS) - if (BUILD_SHARED_LIBS) + if(BUILD_SHARED_LIBS) set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) endif() endif() -if (LLAMA_LTO) +if(LLAMA_LTO) include(CheckIPOSupported) check_ipo_supported(RESULT result OUTPUT output) - if (result) + + if(result) set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE) else() message(WARNING "IPO is not supported: ${output}") @@ -292,99 +297,125 @@ endif() # Architecture specific # TODO: probably these flags need to be tweaked on some architectures -# feel free to update the Makefile for your architecture and send a pull request or issue +# feel free to update the Makefile for your architecture and send a pull request or issue message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") -if (NOT MSVC) - if (LLAMA_STATIC) + +if(NOT MSVC) + if(LLAMA_STATIC) add_link_options(-static) - if (MINGW) + + if(MINGW) add_link_options(-static-libgcc -static-libstdc++) endif() endif() - if (LLAMA_GPROF) + + if(LLAMA_GPROF) add_compile_options(-pg) endif() - if (LLAMA_NATIVE) + + if(LLAMA_NATIVE) add_compile_options(-march=native) endif() endif() -if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") +if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") message(STATUS "ARM detected") - if (MSVC) - # TODO: arm msvc? + + if(MSVC) + # TODO: arm msvc? else() - if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") + if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") # Apple M1, M2, etc. # Raspberry Pi 3, 4, Zero 2 (64-bit) - add_compile_options(-mcpu=native) + if(NOT DEFINED ANDROID_NDK) + add_compile_options(-mcpu=native) + endif() endif() - if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6") + + if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6") # Raspberry Pi 1, Zero - add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access) + if(NOT DEFINED ANDROID_NDK) + add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access) + endif() endif() - if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7") + + if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7") # Raspberry Pi 2 - add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations) + if(NOT DEFINED ANDROID_NDK) + add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations) + endif() endif() - if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8") + + if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8") # Raspberry Pi 3, 4, Zero 2 (32-bit) - add_compile_options(-mfp16-format=ieee -mno-unaligned-access) + if(NOT DEFINED ANDROID_NDK) + add_compile_options(-mfp16-format=ieee -mno-unaligned-access) + endif() endif() endif() -elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$") +elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$") message(STATUS "x86 detected") - if (MSVC) - if (LLAMA_AVX512) + + if(MSVC) + if(LLAMA_AVX512) add_compile_options($<$:/arch:AVX512>) add_compile_options($<$:/arch:AVX512>) + # MSVC has no compile-time flags enabling specific # AVX512 extensions, neither it defines the # macros corresponding to the extensions. # Do it manually. - if (LLAMA_AVX512_VBMI) + if(LLAMA_AVX512_VBMI) add_compile_definitions($<$:__AVX512VBMI__>) add_compile_definitions($<$:__AVX512VBMI__>) endif() - if (LLAMA_AVX512_VNNI) + + if(LLAMA_AVX512_VNNI) add_compile_definitions($<$:__AVX512VNNI__>) add_compile_definitions($<$:__AVX512VNNI__>) endif() - elseif (LLAMA_AVX2) + elseif(LLAMA_AVX2) add_compile_options($<$:/arch:AVX2>) add_compile_options($<$:/arch:AVX2>) - elseif (LLAMA_AVX) + elseif(LLAMA_AVX) add_compile_options($<$:/arch:AVX>) add_compile_options($<$:/arch:AVX>) endif() else() - if (LLAMA_F16C) + if(LLAMA_F16C) add_compile_options(-mf16c) endif() - if (LLAMA_FMA) + + if(LLAMA_FMA) add_compile_options(-mfma) endif() - if (LLAMA_AVX) + + if(LLAMA_AVX) add_compile_options(-mavx) endif() - if (LLAMA_AVX2) + + if(LLAMA_AVX2) add_compile_options(-mavx2) endif() - if (LLAMA_AVX512) + + if(LLAMA_AVX512) add_compile_options(-mavx512f) add_compile_options(-mavx512bw) endif() - if (LLAMA_AVX512_VBMI) + + if(LLAMA_AVX512_VBMI) add_compile_options(-mavx512vbmi) endif() - if (LLAMA_AVX512_VNNI) + + if(LLAMA_AVX512_VNNI) add_compile_options(-mavx512vnni) endif() endif() -elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") +elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") message(STATUS "PowerPC detected") add_compile_options(-mcpu=native -mtune=native) - #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be) + +# TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be) else() message(STATUS "Unknown architecture") endif() @@ -392,59 +423,57 @@ endif() # # Build libraries # - add_library(ggml OBJECT - ggml.c - ggml.h - ${GGML_SOURCES_CUDA} - ${GGML_SOURCES_OPENCL} - ${GGML_SOURCES_METAL} - ) + ggml.c + ggml.h + ${GGML_SOURCES_CUDA} + ${GGML_SOURCES_OPENCL} + ${GGML_SOURCES_METAL} +) target_include_directories(ggml PUBLIC .) target_compile_features(ggml PUBLIC c_std_11) # don't bump target_link_libraries(ggml PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS}) -if (BUILD_SHARED_LIBS) +if(BUILD_SHARED_LIBS) set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON) endif() add_library(llama - llama.cpp - llama.h - llama-util.h - ) + llama.cpp + llama.h + llama-util.h + llama-jni.cpp +) target_include_directories(llama PUBLIC .) target_compile_features(llama PUBLIC cxx_std_11) # don't bump target_link_libraries(llama PRIVATE ggml ${LLAMA_EXTRA_LIBS} - ) +) -if (BUILD_SHARED_LIBS) +if(BUILD_SHARED_LIBS) set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD) endif() -if (GGML_SOURCES_CUDA) +if(GGML_SOURCES_CUDA) message(STATUS "GGML CUDA sources found, configuring CUDA architecture") - set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF) - set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto") + set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF) + set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto") set_property(TARGET llama PROPERTY CUDA_ARCHITECTURES OFF) endif() - # # programs, examples and tests # - -if (LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION) +if(LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION) include(CTest) add_subdirectory(tests) -endif () +endif() -if (LLAMA_BUILD_EXAMPLES) +if(LLAMA_BUILD_EXAMPLES) add_subdirectory(examples) add_subdirectory(pocs) endif() diff --git a/llama-jni.cpp b/llama-jni.cpp new file mode 100644 index 0000000000000..2066eefd910f6 --- /dev/null +++ b/llama-jni.cpp @@ -0,0 +1,12 @@ +#include +#include "llama.h" + +// +// Created by gcpth on 05/06/2023. +// + +extern "C" +JNIEXPORT void JNICALL +Java_com_layla_LlamaCpp_llama_1init_1backend(JNIEnv *env, jclass clazz) { + llama_init_backend(); +} From 1855e0820e7be73957bcef42a6774ee872e5753b Mon Sep 17 00:00:00 2001 From: George Date: Mon, 5 Jun 2023 17:54:57 +0800 Subject: [PATCH 002/180] removed jni files (moved them to caller android project) --- CMakeLists.txt | 1 - llama-jni.cpp | 12 ------------ 2 files changed, 13 deletions(-) delete mode 100644 llama-jni.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index e06046c40cd19..ea9a160ef014c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -443,7 +443,6 @@ add_library(llama llama.cpp llama.h llama-util.h - llama-jni.cpp ) target_include_directories(llama PUBLIC .) diff --git a/llama-jni.cpp b/llama-jni.cpp deleted file mode 100644 index 2066eefd910f6..0000000000000 --- a/llama-jni.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include -#include "llama.h" - -// -// Created by gcpth on 05/06/2023. -// - -extern "C" -JNIEXPORT void JNICALL -Java_com_layla_LlamaCpp_llama_1init_1backend(JNIEnv *env, jclass clazz) { - llama_init_backend(); -} From d84ceb8be0b71dbb383a8ebeda5e23b357f79873 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Thu, 8 Jun 2023 18:47:05 +0800 Subject: [PATCH 003/180] Added support for latest Android NDK which uses Clang instead of GCC --- CMakeLists.txt | 281 +++++++++++++++++++++++-------------------------- 1 file changed, 131 insertions(+), 150 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b05c2535e8147..45b740d735743 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,7 @@ project("llama.cpp" C CXX) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -if(NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) +if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") endif() @@ -13,70 +13,70 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) set(LLAMA_STANDALONE ON) -# configure project version -# TODO + # configure project version + # TODO else() set(LLAMA_STANDALONE OFF) endif() -if(EMSCRIPTEN) +if (EMSCRIPTEN) set(BUILD_SHARED_LIBS_DEFAULT OFF) option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON) else() - if(MINGW) + if (MINGW) set(BUILD_SHARED_LIBS_DEFAULT OFF) else() set(BUILD_SHARED_LIBS_DEFAULT ON) endif() endif() + # # Option list # # general -option(LLAMA_STATIC "llama: static link libraries" OFF) -option(LLAMA_NATIVE "llama: enable -march=native flag" OFF) -option(LLAMA_LTO "llama: enable link time optimization" OFF) +option(LLAMA_STATIC "llama: static link libraries" OFF) +option(LLAMA_NATIVE "llama: enable -march=native flag" OFF) +option(LLAMA_LTO "llama: enable link time optimization" OFF) # debug -option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON) -option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF) -option(LLAMA_GPROF "llama: enable gprof" OFF) +option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON) +option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF) +option(LLAMA_GPROF "llama: enable gprof" OFF) # sanitizers -option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF) -option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF) -option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF) +option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF) +option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF) +option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF) # instruction set specific -option(LLAMA_AVX "llama: enable AVX" ON) -option(LLAMA_AVX2 "llama: enable AVX2" ON) -option(LLAMA_AVX512 "llama: enable AVX512" OFF) -option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF) -option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF) -option(LLAMA_FMA "llama: enable FMA" ON) - +option(LLAMA_AVX "llama: enable AVX" ON) +option(LLAMA_AVX2 "llama: enable AVX2" ON) +option(LLAMA_AVX512 "llama: enable AVX512" OFF) +option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF) +option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF) +option(LLAMA_FMA "llama: enable FMA" ON) # in MSVC F16C is implied with AVX2/AVX512 -if(NOT MSVC) - option(LLAMA_F16C "llama: enable F16C" ON) +if (NOT MSVC) + option(LLAMA_F16C "llama: enable F16C" ON) endif() # 3rd party libs -option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) -option(LLAMA_BLAS "llama: use BLAS" OFF) +option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) +option(LLAMA_BLAS "llama: use BLAS" OFF) set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") -option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) -set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") -set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels") -option(LLAMA_CLBLAST "llama: use CLBlast" OFF) -option(LLAMA_METAL "llama: use Metal" OFF) -option(LLAMA_K_QUANTS "llama: use k-quants" ON) +option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) +set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") +set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels") +option(LLAMA_CLBLAST "llama: use CLBlast" OFF) +option(LLAMA_METAL "llama: use Metal" OFF) +option(LLAMA_K_QUANTS "llama: use k-quants" ON) -option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) -option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) -option(LLAMA_BUILD_SERVER "llama: build server example" OFF) +option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) +option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) +option(LLAMA_BUILD_SERVER "llama: build server example" OFF) # # Build info header @@ -114,6 +114,7 @@ endif() # # Compile flags # + set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD_REQUIRED true) set(CMAKE_C_STANDARD 11) @@ -121,27 +122,26 @@ set(CMAKE_C_STANDARD_REQUIRED true) set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) -if(NOT MSVC) - if(LLAMA_SANITIZE_THREAD) +if (NOT MSVC) + if (LLAMA_SANITIZE_THREAD) add_compile_options(-fsanitize=thread) link_libraries(-fsanitize=thread) endif() - if(LLAMA_SANITIZE_ADDRESS) + if (LLAMA_SANITIZE_ADDRESS) add_compile_options(-fsanitize=address -fno-omit-frame-pointer) link_libraries(-fsanitize=address) endif() - if(LLAMA_SANITIZE_UNDEFINED) + if (LLAMA_SANITIZE_UNDEFINED) add_compile_options(-fsanitize=undefined) link_libraries(-fsanitize=undefined) endif() endif() -if(APPLE AND LLAMA_ACCELERATE) +if (APPLE AND LLAMA_ACCELERATE) find_library(ACCELERATE_FRAMEWORK Accelerate) - - if(ACCELERATE_FRAMEWORK) + if (ACCELERATE_FRAMEWORK) message(STATUS "Accelerate framework found") add_compile_definitions(GGML_USE_ACCELERATE) @@ -151,19 +151,16 @@ if(APPLE AND LLAMA_ACCELERATE) endif() endif() -if(LLAMA_BLAS) - if(LLAMA_STATIC) +if (LLAMA_BLAS) + if (LLAMA_STATIC) set(BLA_STATIC ON) endif() - - if($(CMAKE_VERSION) VERSION_GREATER_EQUAL 3.22) + if ($(CMAKE_VERSION) VERSION_GREATER_EQUAL 3.22) set(BLA_SIZEOF_INTEGER 8) endif() - set(BLA_VENDOR ${LLAMA_BLAS_VENDOR}) find_package(BLAS) - - if(BLAS_FOUND) + if (BLAS_FOUND) message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") add_compile_options(${BLAS_LINKER_FLAGS}) @@ -174,17 +171,16 @@ if(LLAMA_BLAS) include_directories(${BLAS_INCLUDE_DIRS}) else() message(WARNING "BLAS not found, please refer to " - "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" - " to set correct LLAMA_BLAS_VENDOR") + "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" + " to set correct LLAMA_BLAS_VENDOR") endif() endif() -if(LLAMA_CUBLAS) +if (LLAMA_CUBLAS) cmake_minimum_required(VERSION 3.17) find_package(CUDAToolkit) - - if(CUDAToolkit_FOUND) + if (CUDAToolkit_FOUND) message(STATUS "cuBLAS found") enable_language(CUDA) @@ -195,7 +191,7 @@ if(LLAMA_CUBLAS) add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y}) - if(LLAMA_STATIC) + if (LLAMA_STATIC) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) else() set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) @@ -206,10 +202,10 @@ if(LLAMA_CUBLAS) endif() endif() -if(LLAMA_METAL) - find_library(FOUNDATION_LIBRARY Foundation REQUIRED) - find_library(METAL_FRAMEWORK Metal REQUIRED) - find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) +if (LLAMA_METAL) + find_library(FOUNDATION_LIBRARY Foundation REQUIRED) + find_library(METAL_FRAMEWORK Metal REQUIRED) + find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED) set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h) @@ -218,7 +214,7 @@ if(LLAMA_METAL) add_compile_definitions(GGML_METAL_NDEBUG) # get full path to the file - # add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/") + #add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/") # copy ggml-metal.metal to bin directory configure_file(ggml-metal.metal bin/ggml-metal.metal COPYONLY) @@ -228,18 +224,17 @@ if(LLAMA_METAL) ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK} ${METALPERFORMANCE_FRAMEWORK} - ) + ) endif() -if(LLAMA_K_QUANTS) +if (LLAMA_K_QUANTS) set(GGML_SOURCES_EXTRA ${GGML_SOURCES_EXTRA} k_quants.c k_quants.h) add_compile_definitions(GGML_USE_K_QUANTS) endif() -if(LLAMA_CLBLAST) +if (LLAMA_CLBLAST) find_package(CLBlast) - - if(CLBlast_FOUND) + if (CLBlast_FOUND) message(STATUS "CLBlast found") set(GGML_SOURCES_OPENCL ggml-opencl.cpp ggml-opencl.h) @@ -252,8 +247,8 @@ if(LLAMA_CLBLAST) endif() endif() -if(LLAMA_ALL_WARNINGS) - if(NOT MSVC) +if (LLAMA_ALL_WARNINGS) + if (NOT MSVC) set(c_flags -Wall -Wextra @@ -277,24 +272,24 @@ if(LLAMA_ALL_WARNINGS) endif() add_compile_options( - "$<$:${c_flags}>" - "$<$:${cxx_flags}>" + "$<$:${c_flags}>" + "$<$:${cxx_flags}>" ) + endif() -if(MSVC) +if (MSVC) add_compile_definitions(_CRT_SECURE_NO_WARNINGS) - if(BUILD_SHARED_LIBS) + if (BUILD_SHARED_LIBS) set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) endif() endif() -if(LLAMA_LTO) +if (LLAMA_LTO) include(CheckIPOSupported) check_ipo_supported(RESULT result OUTPUT output) - - if(result) + if (result) set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE) else() message(WARNING "IPO is not supported: ${output}") @@ -303,125 +298,108 @@ endif() # Architecture specific # TODO: probably these flags need to be tweaked on some architectures -# feel free to update the Makefile for your architecture and send a pull request or issue +# feel free to update the Makefile for your architecture and send a pull request or issue message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") - -if(NOT MSVC) - if(LLAMA_STATIC) +if (NOT MSVC) + if (LLAMA_STATIC) add_link_options(-static) - - if(MINGW) + if (MINGW) add_link_options(-static-libgcc -static-libstdc++) endif() endif() - - if(LLAMA_GPROF) + if (LLAMA_GPROF) add_compile_options(-pg) endif() - - if(LLAMA_NATIVE) + if (LLAMA_NATIVE) add_compile_options(-march=native) endif() endif() -if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") +if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") message(STATUS "ARM detected") - - if(MSVC) - # TODO: arm msvc? + if (MSVC) + # TODO: arm msvc? else() - if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") + if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") # Apple M1, M2, etc. # Raspberry Pi 3, 4, Zero 2 (64-bit) - if(NOT DEFINED ANDROID_NDK) + + # latest Android NDK uses CLang now, which does not support the native flag + if(DEFINED ANDROID_NDK) + else() add_compile_options(-mcpu=native) endif() endif() - - if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6") + if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6") # Raspberry Pi 1, Zero - if(NOT DEFINED ANDROID_NDK) - add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access) - endif() + add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access) endif() - - if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7") + if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7") # Raspberry Pi 2 - if(NOT DEFINED ANDROID_NDK) - add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations) - endif() + add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access) endif() - - if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8") + if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8") # Raspberry Pi 3, 4, Zero 2 (32-bit) - if(NOT DEFINED ANDROID_NDK) - add_compile_options(-mfp16-format=ieee -mno-unaligned-access) + add_compile_options(-mfp16-format=ieee -mno-unaligned-access) + + if(DEFINED ANDROID_NDK) + add_compile_options(-march=armv8.4a+dotprod) endif() endif() endif() -elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$") +elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$") message(STATUS "x86 detected") - - if(MSVC) - if(LLAMA_AVX512) + if (MSVC) + if (LLAMA_AVX512) add_compile_options($<$:/arch:AVX512>) add_compile_options($<$:/arch:AVX512>) - # MSVC has no compile-time flags enabling specific # AVX512 extensions, neither it defines the # macros corresponding to the extensions. # Do it manually. - if(LLAMA_AVX512_VBMI) + if (LLAMA_AVX512_VBMI) add_compile_definitions($<$:__AVX512VBMI__>) add_compile_definitions($<$:__AVX512VBMI__>) endif() - - if(LLAMA_AVX512_VNNI) + if (LLAMA_AVX512_VNNI) add_compile_definitions($<$:__AVX512VNNI__>) add_compile_definitions($<$:__AVX512VNNI__>) endif() - elseif(LLAMA_AVX2) + elseif (LLAMA_AVX2) add_compile_options($<$:/arch:AVX2>) add_compile_options($<$:/arch:AVX2>) - elseif(LLAMA_AVX) + elseif (LLAMA_AVX) add_compile_options($<$:/arch:AVX>) add_compile_options($<$:/arch:AVX>) endif() else() - if(LLAMA_F16C) + if (LLAMA_F16C) add_compile_options(-mf16c) endif() - - if(LLAMA_FMA) + if (LLAMA_FMA) add_compile_options(-mfma) endif() - - if(LLAMA_AVX) + if (LLAMA_AVX) add_compile_options(-mavx) endif() - - if(LLAMA_AVX2) + if (LLAMA_AVX2) add_compile_options(-mavx2) endif() - - if(LLAMA_AVX512) + if (LLAMA_AVX512) add_compile_options(-mavx512f) add_compile_options(-mavx512bw) endif() - - if(LLAMA_AVX512_VBMI) + if (LLAMA_AVX512_VBMI) add_compile_options(-mavx512vbmi) endif() - - if(LLAMA_AVX512_VNNI) + if (LLAMA_AVX512_VNNI) add_compile_options(-mavx512vnni) endif() endif() -elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") +elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") message(STATUS "PowerPC detected") add_compile_options(-mcpu=native -mtune=native) - -# TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be) + #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be) else() message(STATUS "Unknown architecture") endif() @@ -429,57 +407,60 @@ endif() # # Build libraries # + add_library(ggml OBJECT - ggml.c - ggml.h - ${GGML_SOURCES_CUDA} - ${GGML_SOURCES_OPENCL} - ${GGML_SOURCES_METAL} - ${GGML_SOURCES_EXTRA} -) + ggml.c + ggml.h + ${GGML_SOURCES_CUDA} + ${GGML_SOURCES_OPENCL} + ${GGML_SOURCES_METAL} + ${GGML_SOURCES_EXTRA} + ) target_include_directories(ggml PUBLIC .) target_compile_features(ggml PUBLIC c_std_11) # don't bump target_link_libraries(ggml PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS}) -if(BUILD_SHARED_LIBS) +if (BUILD_SHARED_LIBS) set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON) endif() add_library(llama - llama.cpp - llama.h - llama-util.h -) + llama.cpp + llama.h + llama-util.h + ) target_include_directories(llama PUBLIC .) target_compile_features(llama PUBLIC cxx_std_11) # don't bump target_link_libraries(llama PRIVATE ggml ${LLAMA_EXTRA_LIBS} -) + ) -if(BUILD_SHARED_LIBS) +if (BUILD_SHARED_LIBS) set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD) endif() -if(GGML_SOURCES_CUDA) +if (GGML_SOURCES_CUDA) message(STATUS "GGML CUDA sources found, configuring CUDA architecture") - set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF) - set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto") + set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF) + set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto") set_property(TARGET llama PROPERTY CUDA_ARCHITECTURES OFF) endif() + # # programs, examples and tests # -if(LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION) + +if (LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION) include(CTest) add_subdirectory(tests) -endif() +endif () -if(LLAMA_BUILD_EXAMPLES) +if (LLAMA_BUILD_EXAMPLES) add_subdirectory(examples) add_subdirectory(pocs) -endif() +endif() \ No newline at end of file From b9f7cf702402e08c261710913828b7dac6bee5ed Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Thu, 8 Jun 2023 20:41:00 +0800 Subject: [PATCH 004/180] added -funsafe-math-optimizations flag back to armv7 --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 45b740d735743..8a4dfc6845a6a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -336,7 +336,7 @@ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES endif() if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7") # Raspberry Pi 2 - add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access) + add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations) endif() if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8") # Raspberry Pi 3, 4, Zero 2 (32-bit) From 2d247e3c11f14ba85662fc1d991945501eefa1d4 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 13 Jun 2023 02:05:08 +0800 Subject: [PATCH 005/180] removed NDK check for armv8 because we are passing options from the parent project instead --- CMakeLists.txt | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8384f2acc2380..59140bb587078 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -341,10 +341,6 @@ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8") # Raspberry Pi 3, 4, Zero 2 (32-bit) add_compile_options(-mfp16-format=ieee -mno-unaligned-access) - - if(DEFINED ANDROID_NDK) - add_compile_options(-march=armv8.4a+dotprod) - endif() endif() endif() elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$") @@ -466,4 +462,4 @@ endif () if (LLAMA_BUILD_EXAMPLES) add_subdirectory(examples) add_subdirectory(pocs) -endif() \ No newline at end of file +endif() From 9d2f4a8000b7be85a5a19e78ccbf64cff16558c0 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Wed, 14 Jun 2023 15:09:05 +0800 Subject: [PATCH 006/180] Used local copy of CLBlast instead of installed one --- CMakeLists.txt | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 59140bb587078..b5f25bf87f3e3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -233,18 +233,14 @@ if (LLAMA_K_QUANTS) endif() if (LLAMA_CLBLAST) - find_package(CLBlast) - if (CLBlast_FOUND) - message(STATUS "CLBlast found") + # build CLBlast from source + add_subdirectory(../CLBlast ${CMAKE_CURRENT_BINARY_DIR}/clblast) - set(GGML_SOURCES_OPENCL ggml-opencl.cpp ggml-opencl.h) + set(GGML_SOURCES_OPENCL ggml-opencl.cpp ggml-opencl.h) - add_compile_definitions(GGML_USE_CLBLAST) + add_compile_definitions(GGML_USE_CLBLAST) - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast) - else() - message(WARNING "CLBlast not found") - endif() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast) endif() if (LLAMA_ALL_WARNINGS) From 670390f915c50b277274e7ec60f06cc6364c3057 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Sun, 18 Jun 2023 13:31:49 +0800 Subject: [PATCH 007/180] typo fix when calculation memory address for dest_t in ggml_compute_forward_add_q_f32 --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 0eda7f338e69b..ca22ba21ed0d2 100644 --- a/ggml.c +++ b/ggml.c @@ -7898,7 +7898,7 @@ static void ggml_compute_forward_add_q_f32( void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)); - void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb0)); + void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); assert(ne00 % 32 == 0); From ced8e8d40d69461925b799c42043ef2caa579fd3 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Mon, 19 Jun 2023 14:46:15 +0800 Subject: [PATCH 008/180] fixed issue: memory is not guaranteed to be aligned properly during ggml_init call from loading saved sessions --- llama.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index 2105e32799ae9..ce33f7fbee15f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3122,9 +3122,9 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { if (kv_size) { const size_t elt_size = ggml_element_size(kv_self.k); - char buffer[4096]; + //char buffer[4096]; - ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true }); + ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); ggml_cgraph gf{}; gf.n_threads = 1; From 012a18847d9b450af090a9fc94f3240443695706 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Sat, 1 Jul 2023 10:08:14 +0800 Subject: [PATCH 009/180] bumped llama.cpp from source --- llama.cpp | 2354 ++++++++++++++++++++--------------------------------- 1 file changed, 901 insertions(+), 1453 deletions(-) diff --git a/llama.cpp b/llama.cpp index 9c917a70a373c..97900e56a0856 100644 --- a/llama.cpp +++ b/llama.cpp @@ -50,15 +50,14 @@ #include #if defined(_MSC_VER) -#pragma warning(disable : 4244 4267) // possible loss of data +#pragma warning(disable: 4244 4267) // possible loss of data #endif #define LLAMA_USE_SCRATCH #define LLAMA_MAX_SCRATCH_BUFFERS 16 // available llama models -enum e_model -{ +enum e_model { MODEL_UNKNOWN, MODEL_3B, MODEL_7B, @@ -67,123 +66,116 @@ enum e_model MODEL_65B, }; -static const size_t MB = 1024 * 1024; +static const size_t MB = 1024*1024; // computed for n_ctx == 2048 // TODO: dynamically determine these sizes // needs modifications in ggml -typedef void (*offload_func_t)(struct ggml_tensor *tensor); +typedef void (*offload_func_t)(struct ggml_tensor * tensor); -void llama_nop(struct ggml_tensor *tensor) -{ // don't offload by default - (void)tensor; +void llama_nop(struct ggml_tensor * tensor) { // don't offload by default + (void) tensor; } -static const std::map &MEM_REQ_SCRATCH0() +static const std::map & MEM_REQ_SCRATCH0() { static std::map k_sizes = { - {MODEL_3B, 256ull * MB}, - {MODEL_7B, 512ull * MB}, - {MODEL_13B, 512ull * MB}, - {MODEL_30B, 512ull * MB}, - {MODEL_65B, 1024ull * MB}, + { MODEL_3B, 256ull * MB }, + { MODEL_7B, 512ull * MB }, + { MODEL_13B, 512ull * MB }, + { MODEL_30B, 512ull * MB }, + { MODEL_65B, 1024ull * MB }, }; return k_sizes; } -static const std::map &MEM_REQ_SCRATCH1() +static const std::map & MEM_REQ_SCRATCH1() { static std::map k_sizes = { - {MODEL_3B, 256ull * MB}, - {MODEL_7B, 512ull * MB}, - {MODEL_13B, 512ull * MB}, - {MODEL_30B, 512ull * MB}, - {MODEL_65B, 1024ull * MB}, + { MODEL_3B, 256ull * MB }, + { MODEL_7B, 512ull * MB }, + { MODEL_13B, 512ull * MB }, + { MODEL_30B, 512ull * MB }, + { MODEL_65B, 1024ull * MB }, }; return k_sizes; } // 2*n_embd*n_ctx*n_layer*sizeof(float16) -static const std::map &MEM_REQ_KV_SELF() +static const std::map & MEM_REQ_KV_SELF() { static std::map k_sizes = { - {MODEL_3B, 682ull * MB}, - {MODEL_7B, 1026ull * MB}, - {MODEL_13B, 1608ull * MB}, - {MODEL_30B, 3124ull * MB}, - {MODEL_65B, 5120ull * MB}, + { MODEL_3B, 682ull * MB }, + { MODEL_7B, 1026ull * MB }, + { MODEL_13B, 1608ull * MB }, + { MODEL_30B, 3124ull * MB }, + { MODEL_65B, 5120ull * MB }, }; return k_sizes; } // this is mostly needed for temporary mul_mat buffers to dequantize the data // not actually needed if BLAS is disabled -static const std::map &MEM_REQ_EVAL() +static const std::map & MEM_REQ_EVAL() { static std::map k_sizes = { - {MODEL_3B, 512ull * MB}, - {MODEL_7B, 768ull * MB}, - {MODEL_13B, 1024ull * MB}, - {MODEL_30B, 1280ull * MB}, - {MODEL_65B, 1536ull * MB}, + { MODEL_3B, 512ull * MB }, + { MODEL_7B, 768ull * MB }, + { MODEL_13B, 1024ull * MB }, + { MODEL_30B, 1280ull * MB }, + { MODEL_65B, 1536ull * MB }, }; return k_sizes; } // default hparams (LLaMA 7B) -struct llama_hparams -{ +struct llama_hparams { uint32_t n_vocab = 32000; - uint32_t n_ctx = 512; // this is provided as user input? - uint32_t n_embd = 4096; - uint32_t n_mult = 256; - uint32_t n_head = 32; + uint32_t n_ctx = 512; // this is provided as user input? + uint32_t n_embd = 4096; + uint32_t n_mult = 256; + uint32_t n_head = 32; uint32_t n_layer = 32; - uint32_t n_rot = 64; + uint32_t n_rot = 64; enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16; - bool operator!=(const llama_hparams &other) const - { + bool operator!=(const llama_hparams & other) const { return static_cast(memcmp(this, &other, sizeof(llama_hparams))); } }; -struct llama_layer -{ +struct llama_layer { // normalization - struct ggml_tensor *attention_norm; + struct ggml_tensor * attention_norm; // attention - struct ggml_tensor *wq; - struct ggml_tensor *wk; - struct ggml_tensor *wv; - struct ggml_tensor *wo; + struct ggml_tensor * wq; + struct ggml_tensor * wk; + struct ggml_tensor * wv; + struct ggml_tensor * wo; // normalization - struct ggml_tensor *ffn_norm; + struct ggml_tensor * ffn_norm; // ff - struct ggml_tensor *w1; - struct ggml_tensor *w2; - struct ggml_tensor *w3; + struct ggml_tensor * w1; + struct ggml_tensor * w2; + struct ggml_tensor * w3; }; -struct llama_kv_cache -{ - struct ggml_tensor *k; - struct ggml_tensor *v; +struct llama_kv_cache { + struct ggml_tensor * k; + struct ggml_tensor * v; - struct ggml_context *ctx = NULL; + struct ggml_context * ctx = NULL; llama_ctx_buffer buf; int n; // number of tokens currently in the cache - ~llama_kv_cache() - { - if (ctx) - { + ~llama_kv_cache() { + if (ctx) { ggml_free(ctx); } @@ -194,13 +186,11 @@ struct llama_kv_cache } }; -struct llama_vocab -{ - using id = int32_t; +struct llama_vocab { + using id = int32_t; using token = std::string; - struct token_score - { + struct token_score { token tok; float score; }; @@ -209,22 +199,21 @@ struct llama_vocab std::vector id_to_token; }; -struct llama_model -{ +struct llama_model { e_model type = MODEL_UNKNOWN; llama_hparams hparams; - struct ggml_tensor *tok_embeddings; + struct ggml_tensor * tok_embeddings; - struct ggml_tensor *norm; - struct ggml_tensor *output; + struct ggml_tensor * norm; + struct ggml_tensor * output; std::vector layers; int n_gpu_layers; // context - struct ggml_context *ctx = NULL; + struct ggml_context * ctx = NULL; // the model memory buffer llama_ctx_buffer buf; @@ -244,46 +233,41 @@ struct llama_model llama_vocab vocab; - ~llama_model() - { - if (ctx) - { + ~llama_model() { + if (ctx) { ggml_free(ctx); } #ifdef GGML_USE_CUBLAS - for (size_t i = 0; i < tensors_by_name.size(); ++i) - { + for (size_t i = 0; i < tensors_by_name.size(); ++i) { ggml_cuda_free_data(tensors_by_name[i].second); } ggml_cuda_free_scratch(); #elif defined(GGML_USE_CLBLAST) - for (size_t i = 0; i < tensors_by_name.size(); ++i) - { + for (size_t i = 0; i < tensors_by_name.size(); ++i) { ggml_cl_free_data(tensors_by_name[i].second); } #endif } }; -struct llama_context -{ - llama_context(const llama_model &model, const llama_vocab &vocab) : model(model), vocab(vocab), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {} +struct llama_context { + llama_context(const llama_model & model, const llama_vocab & vocab) : model(model), vocab(vocab), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {} std::mt19937 rng; bool has_evaluated_once = false; int64_t t_sample_us = 0; - int64_t t_eval_us = 0; + int64_t t_eval_us = 0; int64_t t_p_eval_us = 0; int32_t n_sample = 0; // number of tokens sampled - int32_t n_eval = 0; // number of eval calls + int32_t n_eval = 0; // number of eval calls int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) - const llama_model &model; - const llama_vocab &vocab; + const llama_model & model; + const llama_vocab & vocab; bool model_owner = false; @@ -308,95 +292,73 @@ struct llama_context llama_ctx_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS]; #ifdef GGML_USE_METAL - ggml_metal_context *ctx_metal = NULL; + ggml_metal_context * ctx_metal = NULL; #endif - int buf_last = 0; - size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = {0}; + int buf_last = 0; + size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 }; - void use_buf(struct ggml_context *ctx, int i) - { + void use_buf(struct ggml_context * ctx, int i) { #if defined(LLAMA_USE_SCRATCH) size_t last_size = 0; - if (i == -1) - { - last_size = ggml_set_scratch(ctx, { - 0, - 0, - nullptr, - }); - } - else - { - auto &buf = buf_scratch[i]; - last_size = ggml_set_scratch(ctx, { - 0, - buf.size, - buf.addr, - }); + if (i == -1) { + last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, }); + } else { + auto & buf = buf_scratch[i]; + last_size = ggml_set_scratch(ctx, { 0, buf.size, buf.addr, }); } - if (buf_last >= 0) - { + if (buf_last >= 0) { buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size); } buf_last = i; #else - (void)i; - (void)ctx; + (void) i; + (void) ctx; #endif } - size_t get_buf_max_mem(int i) const - { + size_t get_buf_max_mem(int i) const { #if defined(LLAMA_USE_SCRATCH) return buf_max_size[i]; #else - (void)i; + (void) i; return 0; #endif } }; template -static T checked_mul(T a, T b) -{ +static T checked_mul(T a, T b) { T ret = a * b; - if (a != 0 && ret / a != b) - { + if (a != 0 && ret / a != b) { throw std::runtime_error(format("overflow multiplying %llu * %llu", - (unsigned long long)a, (unsigned long long)b)); + (unsigned long long) a, (unsigned long long) b)); } return ret; } -static size_t checked_div(size_t a, size_t b) -{ - if (b == 0 || a % b != 0) - { +static size_t checked_div(size_t a, size_t b) { + if (b == 0 || a % b != 0) { throw std::runtime_error(format("error dividing %zu / %zu", a, b)); } return a / b; } -static std::string llama_format_tensor_shape(const std::vector &ne) -{ +static std::string llama_format_tensor_shape(const std::vector & ne) { char buf[256]; snprintf(buf, sizeof(buf), "%5u", ne.at(0)); - for (size_t i = 1; i < ne.size(); i++) - { + for (size_t i = 1; i < ne.size(); i++) { snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), " x %5u", ne.at(i)); } return buf; } -static size_t llama_calc_tensor_size(const std::vector &ne, enum ggml_type type) -{ +static size_t llama_calc_tensor_size(const std::vector & ne, enum ggml_type type) { size_t size = ggml_type_size(type); - for (uint32_t dim : ne) - { + for (uint32_t dim : ne) { size = checked_mul(size, dim); } return size / ggml_blck_size(type); @@ -412,15 +374,13 @@ struct llama_load_tensor { uint8_t * data; }; -struct llama_load_tensors_map -{ +struct llama_load_tensors_map { // tensors is kept in a separate vector to preserve file order std::vector tensors; std::unordered_map name_to_idx; }; -enum llama_file_version -{ +enum llama_file_version { LLAMA_FILE_VERSION_GGML, LLAMA_FILE_VERSION_GGMF_V1, // added version field and scores in vocab LLAMA_FILE_VERSION_GGJT_V1, // added padding @@ -428,8 +388,7 @@ enum llama_file_version LLAMA_FILE_VERSION_GGJT_V3, // changed Q4 and Q8 quantization format }; -struct llama_file_loader -{ +struct llama_file_loader { llama_file file; llama_file_version file_version; llama_hparams hparams; @@ -443,74 +402,57 @@ struct llama_file_loader read_vocab(); read_tensor_metadata(tensors_map); } - void read_magic() - { + void read_magic() { uint32_t magic = file.read_u32(); - if (magic == LLAMA_FILE_MAGIC_GGML) - { + if (magic == LLAMA_FILE_MAGIC_GGML) { file_version = LLAMA_FILE_VERSION_GGML; return; } uint32_t version = file.read_u32(); - switch (magic) - { - case LLAMA_FILE_MAGIC_GGMF: - switch (version) - { - case 1: - file_version = LLAMA_FILE_VERSION_GGMF_V1; - return; - } - break; - case LLAMA_FILE_MAGIC_GGJT: - switch (version) - { - case 1: - file_version = LLAMA_FILE_VERSION_GGJT_V1; - return; - case 2: - file_version = LLAMA_FILE_VERSION_GGJT_V2; - return; - case 3: - file_version = LLAMA_FILE_VERSION_GGJT_V3; - return; - } + switch (magic) { + case LLAMA_FILE_MAGIC_GGMF: + switch (version) { + case 1: file_version = LLAMA_FILE_VERSION_GGMF_V1; return; + } + break; + case LLAMA_FILE_MAGIC_GGJT: + switch (version) { + case 1: file_version = LLAMA_FILE_VERSION_GGJT_V1; return; + case 2: file_version = LLAMA_FILE_VERSION_GGJT_V2; return; + case 3: file_version = LLAMA_FILE_VERSION_GGJT_V3; return; + } } throw std::runtime_error(format("unknown (magic, version) combination: %08x, %08x; is this really a GGML file?", - magic, version)); + magic, version)); } - void read_hparams() - { + void read_hparams() { hparams.n_vocab = file.read_u32(); hparams.n_embd = file.read_u32(); hparams.n_mult = file.read_u32(); hparams.n_head = file.read_u32(); hparams.n_layer = file.read_u32(); hparams.n_rot = file.read_u32(); - hparams.ftype = (enum llama_ftype)file.read_u32(); + hparams.ftype = (enum llama_ftype) file.read_u32(); } - void read_vocab() - { + void read_vocab() { vocab.id_to_token.resize(hparams.n_vocab); - for (uint32_t i = 0; i < hparams.n_vocab; i++) - { + for (uint32_t i = 0; i < hparams.n_vocab; i++) { uint32_t len = file.read_u32(); std::string word = file.read_string(len); float score = 0.0f; - if (file_version >= LLAMA_FILE_VERSION_GGMF_V1) - { + if (file_version >= LLAMA_FILE_VERSION_GGMF_V1) { file.read_raw(&score, sizeof(score)); } vocab.token_to_id[word] = i; - auto &tok_score = vocab.id_to_token[i]; + auto & tok_score = vocab.id_to_token[i]; tok_score.tok = std::move(word); tok_score.score = score; } @@ -524,8 +466,7 @@ struct llama_file_loader tensor.ne.resize(n_dims); file.read_raw(tensor.ne.data(), sizeof(tensor.ne[0]) * n_dims); std::string name = file.read_string(name_len); - if (n_dims < 1 || n_dims > 2) - { + if (n_dims < 1 || n_dims > 2) { throw std::runtime_error(format("llama.cpp: tensor '%s' should not be %u-dimensional", name.c_str(), n_dims)); } switch (tensor.type) { @@ -612,105 +553,10 @@ struct llama_file_saver { case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: break; - default: - { - throw std::runtime_error(format("unrecognized tensor type %u\n", shard.type)); - } - } - - if (file_version >= LLAMA_FILE_VERSION_GGJT_V1) - { - // skip to the next multiple of 32 bytes - file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); - } - shard.file_idx = file_idx; - shard.file_off = file.tell(); - - shard.calc_size(); - file.seek(shard.size, SEEK_CUR); - - auto it = tensors_map.name_to_idx.find(name); - size_t idx; - if (it != tensors_map.name_to_idx.end()) - { - idx = it->second; - } - else - { - tensors_map.tensors.emplace_back(name); - idx = tensors_map.tensors.size() - 1; - tensors_map.name_to_idx.emplace(name, idx); - } - tensors_map.tensors.at(idx).shards.push_back(shard); + default: LLAMA_ASSERT(false); } - } -}; - -struct llama_file_saver -{ - llama_file file; - llama_file_loader *any_file_loader; - llama_file_saver(const char *fname, llama_file_loader *any_file_loader, enum llama_ftype new_ftype) - : file(fname, "wb"), any_file_loader(any_file_loader) - { - fprintf(stderr, "llama.cpp: saving model to %s\n", fname); - write_magic(); - write_hparams(new_ftype); - write_vocab(); - } - void write_magic() - { - file.write_u32(LLAMA_FILE_MAGIC); // magic - file.write_u32(LLAMA_FILE_VERSION); // version - } - void write_hparams(enum llama_ftype new_ftype) - { - const llama_hparams &hparams = any_file_loader->hparams; - file.write_u32(hparams.n_vocab); - file.write_u32(hparams.n_embd); - file.write_u32(hparams.n_mult); - file.write_u32(hparams.n_head); - file.write_u32(hparams.n_layer); - file.write_u32(hparams.n_rot); - file.write_u32(new_ftype); - } - void write_vocab() - { - if (any_file_loader->file_version == LLAMA_FILE_VERSION_GGML) - { - fprintf(stderr, "llama.cpp: WARNING: input is an old file that doesn't have scores; will add dummy scores\n"); - } - uint32_t n_vocab = any_file_loader->hparams.n_vocab; - for (uint32_t i = 0; i < n_vocab; i++) - { - const auto &token_score = any_file_loader->vocab.id_to_token.at(i); - file.write_u32((uint32_t)token_score.tok.size()); - file.write_raw(token_score.tok.data(), token_score.tok.size()); - file.write_raw(&token_score.score, sizeof(token_score.score)); - } - } - void write_tensor(llama_load_tensor &tensor, enum ggml_type new_type, const void *new_data, size_t new_size) - { - switch (new_type) - { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - break; - default: - LLAMA_ASSERT(false); - } - file.write_u32((uint32_t)tensor.ne.size()); - file.write_u32((uint32_t)tensor.name.size()); + file.write_u32((uint32_t) tensor.ne.size()); + file.write_u32((uint32_t) tensor.name.size()); file.write_u32(new_type); file.write_raw(tensor.ne.data(), sizeof(tensor.ne[0]) * tensor.ne.size()); file.write_raw(tensor.name.data(), tensor.name.size()); @@ -725,7 +571,7 @@ struct llama_model_loader { llama_load_tensors_map tensors_map; bool use_mmap; size_t num_ggml_tensors_created = 0; - struct ggml_context *ggml_ctx = NULL; + struct ggml_context * ggml_ctx = NULL; std::unique_ptr mapping; llama_model_loader(const std::string & fname_base, bool use_mmap) { @@ -736,54 +582,43 @@ struct llama_model_loader { this->use_mmap = use_mmap; } - void calc_sizes(size_t *ctx_size_p, size_t *mmapped_size_p) const - { + void calc_sizes(size_t * ctx_size_p, size_t * mmapped_size_p) const { *ctx_size_p = *mmapped_size_p = 0; - for (const llama_load_tensor < : tensors_map.tensors) - { + for (const llama_load_tensor & lt : tensors_map.tensors) { *ctx_size_p += sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE; *(use_mmap ? mmapped_size_p : ctx_size_p) += lt.size; } } - struct ggml_tensor *get_tensor(const std::string &name, const std::vector &ne, ggml_backend backend) - { + struct ggml_tensor * get_tensor(const std::string & name, const std::vector & ne, ggml_backend backend) { auto it = tensors_map.name_to_idx.find(name); - if (it == tensors_map.name_to_idx.end()) - { + if (it == tensors_map.name_to_idx.end()) { throw std::runtime_error(std::runtime_error(format("llama.cpp: tensor '%s' is missing from model", name.c_str()))); } - llama_load_tensor < = tensors_map.tensors.at(it->second); - if (lt.ne != ne) - { + llama_load_tensor & lt = tensors_map.tensors.at(it->second); + if (lt.ne != ne) { throw std::runtime_error(format("llama.cpp: tensor '%s' has wrong shape; expected %s, got %s", - name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str())); + name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str())); } return get_tensor_for(lt, backend); } - struct ggml_tensor *get_tensor_for(llama_load_tensor <, ggml_backend backend) - { - struct ggml_tensor *tensor; - if (backend != GGML_BACKEND_CPU) - { + struct ggml_tensor * get_tensor_for(llama_load_tensor & lt, ggml_backend backend) { + struct ggml_tensor * tensor; + if (backend != GGML_BACKEND_CPU) { ggml_set_no_alloc(ggml_ctx, true); } - if (lt.ne.size() == 2) - { + if (lt.ne.size() == 2) { tensor = ggml_new_tensor_2d(ggml_ctx, lt.type, lt.ne.at(0), lt.ne.at(1)); - } - else - { + } else { LLAMA_ASSERT(lt.ne.size() == 1); tensor = ggml_new_tensor_1d(ggml_ctx, lt.type, lt.ne.at(0)); } ggml_set_name(tensor, lt.name.c_str()); LLAMA_ASSERT(lt.ggml_tensor == NULL); // if this fails, we called get_tensor twice on the same tensor - if (backend != GGML_BACKEND_CPU) - { + if (backend != GGML_BACKEND_CPU) { ggml_set_no_alloc(ggml_ctx, use_mmap); } tensor->backend = backend; @@ -792,24 +627,19 @@ struct llama_model_loader { return tensor; } - void done_getting_tensors() const - { - if (num_ggml_tensors_created != tensors_map.tensors.size()) - { + void done_getting_tensors() const { + if (num_ggml_tensors_created != tensors_map.tensors.size()) { throw std::runtime_error(std::string("llama.cpp: file contained more tensors than expected")); } } - void load_all_data(llama_progress_callback progress_callback, void *progress_callback_user_data, llama_mlock *lmlock) - { + void load_all_data(llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) { size_t data_size = 0; size_t prefetch_size = 0; size_t lock_size = 0; - for (const llama_load_tensor < : tensors_map.tensors) - { + for (const llama_load_tensor & lt : tensors_map.tensors) { data_size += lt.size; - if (lt.ggml_tensor->backend == GGML_BACKEND_CPU) - { + if (lt.ggml_tensor->backend == GGML_BACKEND_CPU) { prefetch_size += lt.size; } } @@ -822,54 +652,47 @@ struct llama_model_loader { } size_t done_size = 0; - for (llama_load_tensor < : tensors_map.tensors) - { - if (progress_callback) - { - progress_callback((float)done_size / data_size, progress_callback_user_data); + for (llama_load_tensor & lt : tensors_map.tensors) { + if (progress_callback) { + progress_callback((float) done_size / data_size, progress_callback_user_data); } LLAMA_ASSERT(lt.ggml_tensor); // unused tensors should have been caught by load_data already - lt.data = (uint8_t *)lt.ggml_tensor->data; + lt.data = (uint8_t *) lt.ggml_tensor->data; // allocate temp buffer if not using mmap - if (!use_mmap && lt.data == NULL) - { + if (!use_mmap && lt.data == NULL) { GGML_ASSERT(lt.ggml_tensor->backend != GGML_BACKEND_CPU); - lt.data = (uint8_t *)malloc(ggml_nbytes(lt.ggml_tensor)); + lt.data = (uint8_t*)malloc(ggml_nbytes(lt.ggml_tensor)); } load_data_for(lt); - switch (lt.ggml_tensor->backend) - { - case GGML_BACKEND_CPU: - lt.ggml_tensor->data = lt.data; - if (use_mmap && lmlock) - { - lock_size += lt.size; - lmlock->grow_to(lock_size); - } - break; + switch(lt.ggml_tensor->backend) { + case GGML_BACKEND_CPU: + lt.ggml_tensor->data = lt.data; + if (use_mmap && lmlock) { + lock_size += lt.size; + lmlock->grow_to(lock_size); + } + break; #if defined(GGML_USE_CUBLAS) - case GGML_BACKEND_GPU: - case GGML_BACKEND_GPU_SPLIT: - ggml_cuda_transform_tensor(lt.data, lt.ggml_tensor); - if (!use_mmap) - { - free(lt.data); - } - break; + case GGML_BACKEND_GPU: + case GGML_BACKEND_GPU_SPLIT: + ggml_cuda_transform_tensor(lt.data, lt.ggml_tensor); + if (!use_mmap) { + free(lt.data); + } + break; #elif defined(GGML_USE_CLBLAST) - case GGML_BACKEND_GPU: - ggml_cl_transform_tensor(lt.data, lt.ggml_tensor); - if (!use_mmap) - { - free(lt.data); - } - break; + case GGML_BACKEND_GPU: + ggml_cl_transform_tensor(lt.data, lt.ggml_tensor); + if (!use_mmap) { + free(lt.data); + } + break; #endif - default: - continue; + default: + continue; } done_size += lt.size; @@ -890,48 +713,46 @@ struct llama_model_loader { } } - static void print_checksum(llama_load_tensor <) - { + static void print_checksum(llama_load_tensor & lt) { uint32_t sum = 0; - for (size_t i = 0; i < lt.size; i++) - { + for (size_t i = 0; i < lt.size; i++) { uint8_t byte = lt.data[i]; sum = byte + (sum << 6) + (sum << 16) - sum; // sdbm hash } fprintf(stderr, "%s checksum: %#08x (%s, size %zu)\n", lt.name.c_str(), sum, llama_format_tensor_shape(lt.ne).c_str(), lt.size); } + }; + // // kv cache // static bool kv_cache_init( - const struct llama_hparams &hparams, - struct llama_kv_cache &cache, - ggml_type wtype, - int n_ctx, - int n_gpu_layers) -{ - const int n_embd = hparams.n_embd; + const struct llama_hparams & hparams, + struct llama_kv_cache & cache, + ggml_type wtype, + int n_ctx, + int n_gpu_layers) { + const int n_embd = hparams.n_embd; const int n_layer = hparams.n_layer; - const int64_t n_mem = n_layer * n_ctx; - const int64_t n_elements = n_embd * n_mem; + const int64_t n_mem = n_layer*n_ctx; + const int64_t n_elements = n_embd*n_mem; - cache.buf.resize(2u * n_elements * ggml_type_size(wtype) + 2u * MB); + cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB); cache.n = 0; struct ggml_init_params params; - params.mem_size = cache.buf.size; + params.mem_size = cache.buf.size; params.mem_buffer = cache.buf.addr; - params.no_alloc = false; + params.no_alloc = false; cache.ctx = ggml_init(params); - if (!cache.ctx) - { + if (!cache.ctx) { fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); return false; } @@ -941,14 +762,12 @@ static bool kv_cache_init( ggml_set_name(cache.k, "cache_k"); ggml_set_name(cache.v, "cache_v"); - (void)n_gpu_layers; + (void) n_gpu_layers; #ifdef GGML_USE_CUBLAS - if (n_gpu_layers > n_layer + 1) - { + if (n_gpu_layers > n_layer + 1) { ggml_cuda_assign_buffers_no_scratch(cache.v); } - if (n_gpu_layers > n_layer + 2) - { + if (n_gpu_layers > n_layer + 2) { ggml_cuda_assign_buffers_no_scratch(cache.k); } #endif // GGML_USE_CUBLAS @@ -956,8 +775,7 @@ static bool kv_cache_init( return true; } -struct llama_context_params llama_context_default_params() -{ +struct llama_context_params llama_context_default_params() { struct llama_context_params result = { /*.seed =*/ LLAMA_DEFAULT_SEED, /*.n_ctx =*/ 512, @@ -979,25 +797,22 @@ struct llama_context_params llama_context_default_params() return result; } -struct llama_model_quantize_params llama_model_quantize_default_params() -{ +struct llama_model_quantize_params llama_model_quantize_default_params() { struct llama_model_quantize_params result = { - /*.nthread =*/0, - /*.ftype =*/LLAMA_FTYPE_MOSTLY_Q5_1, - /*.allow_requantize =*/false, - /*.quantize_output_tensor =*/true, + /*.nthread =*/ 0, + /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1, + /*.allow_requantize =*/ false, + /*.quantize_output_tensor =*/ true, }; return result; } -bool llama_mmap_supported() -{ +bool llama_mmap_supported() { return llama_mmap::SUPPORTED; } -bool llama_mlock_supported() -{ +bool llama_mlock_supported() { return llama_mlock::SUPPORTED; } @@ -1006,8 +821,8 @@ void llama_init_backend(bool numa) { // needed to initialize f16 tables { - struct ggml_init_params params = {0, NULL, false}; - struct ggml_context *ctx = ggml_init(params); + struct ggml_init_params params = { 0, NULL, false }; + struct ggml_context * ctx = ggml_init(params); ggml_free(ctx); } @@ -1016,8 +831,7 @@ void llama_init_backend(bool numa) { } } -int64_t llama_time_us() -{ +int64_t llama_time_us() { return ggml_time_us(); } @@ -1025,105 +839,70 @@ int64_t llama_time_us() // model loading // -static const char *llama_file_version_name(llama_file_version version) -{ - switch (version) - { - case LLAMA_FILE_VERSION_GGML: - return "'ggml' (old version with low tokenizer quality and no mmap support)"; - case LLAMA_FILE_VERSION_GGMF_V1: - return "ggmf v1 (old version with no mmap support)"; - case LLAMA_FILE_VERSION_GGJT_V1: - return "ggjt v1 (pre #1405)"; - case LLAMA_FILE_VERSION_GGJT_V2: - return "ggjt v2 (pre #1508)"; - case LLAMA_FILE_VERSION_GGJT_V3: - return "ggjt v3 (latest)"; +static const char *llama_file_version_name(llama_file_version version) { + switch (version) { + case LLAMA_FILE_VERSION_GGML: return "'ggml' (old version with low tokenizer quality and no mmap support)"; + case LLAMA_FILE_VERSION_GGMF_V1: return "ggmf v1 (old version with no mmap support)"; + case LLAMA_FILE_VERSION_GGJT_V1: return "ggjt v1 (pre #1405)"; + case LLAMA_FILE_VERSION_GGJT_V2: return "ggjt v2 (pre #1508)"; + case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (latest)"; } return "unknown"; } -static const char *llama_ftype_name(enum llama_ftype ftype) -{ - switch (ftype) - { - case LLAMA_FTYPE_ALL_F32: - return "all F32"; - case LLAMA_FTYPE_MOSTLY_F16: - return "mostly F16"; - case LLAMA_FTYPE_MOSTLY_Q4_0: - return "mostly Q4_0"; - case LLAMA_FTYPE_MOSTLY_Q4_1: - return "mostly Q4_1"; - case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: - return "mostly Q4_1, some F16"; - case LLAMA_FTYPE_MOSTLY_Q5_0: - return "mostly Q5_0"; - case LLAMA_FTYPE_MOSTLY_Q5_1: - return "mostly Q5_1"; - case LLAMA_FTYPE_MOSTLY_Q8_0: - return "mostly Q8_0"; - // K-quants - case LLAMA_FTYPE_MOSTLY_Q2_K: - return "mostly Q2_K"; - case LLAMA_FTYPE_MOSTLY_Q3_K_S: - return "mostly Q3_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q3_K_M: - return "mostly Q3_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q3_K_L: - return "mostly Q3_K - Large"; - case LLAMA_FTYPE_MOSTLY_Q4_K_S: - return "mostly Q4_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q4_K_M: - return "mostly Q4_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q5_K_S: - return "mostly Q5_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q5_K_M: - return "mostly Q5_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q6_K: - return "mostly Q6_K"; - default: - return "unknown, may not work"; - } -} - -static const char *llama_model_type_name(e_model type) -{ - switch (type) - { - case MODEL_3B: - return "3B"; - case MODEL_7B: - return "7B"; - case MODEL_13B: - return "13B"; - case MODEL_30B: - return "30B"; - case MODEL_65B: - return "65B"; - default: - LLAMA_ASSERT(false); +static const char *llama_ftype_name(enum llama_ftype ftype) { + switch (ftype) { + case LLAMA_FTYPE_ALL_F32: return "all F32"; + case LLAMA_FTYPE_MOSTLY_F16: return "mostly F16"; + case LLAMA_FTYPE_MOSTLY_Q4_0: return "mostly Q4_0"; + case LLAMA_FTYPE_MOSTLY_Q4_1: return "mostly Q4_1"; + case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: + return "mostly Q4_1, some F16"; + case LLAMA_FTYPE_MOSTLY_Q5_0: return "mostly Q5_0"; + case LLAMA_FTYPE_MOSTLY_Q5_1: return "mostly Q5_1"; + case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0"; + // K-quants + case LLAMA_FTYPE_MOSTLY_Q2_K: return "mostly Q2_K"; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "mostly Q3_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "mostly Q3_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "mostly Q3_K - Large"; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "mostly Q4_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "mostly Q4_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "mostly Q5_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "mostly Q5_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q6_K: return "mostly Q6_K"; + default: return "unknown, may not work"; + } +} + +static const char *llama_model_type_name(e_model type) { + switch (type) { + case MODEL_3B: return "3B"; + case MODEL_7B: return "7B"; + case MODEL_13B: return "13B"; + case MODEL_30B: return "30B"; + case MODEL_65B: return "65B"; + default: LLAMA_ASSERT(false); } } static void llama_model_load_internal( - const std::string &fname, - llama_model &model, - llama_vocab &vocab, - int n_ctx, - int n_batch, - int n_gpu_layers, - int main_gpu, - const float *tensor_split, - bool low_vram, - ggml_type memory_type, - bool use_mmap, - bool use_mlock, - bool vocab_only, - llama_progress_callback progress_callback, - void *progress_callback_user_data) -{ + const std::string & fname, + llama_model & model, + llama_vocab & vocab, + int n_ctx, + int n_batch, + int n_gpu_layers, + int main_gpu, + const float * tensor_split, + bool low_vram, + ggml_type memory_type, + bool use_mmap, + bool use_mlock, + bool vocab_only, + llama_progress_callback progress_callback, + void * progress_callback_user_data) { model.t_start_us = ggml_time_us(); @@ -1136,118 +915,98 @@ static void llama_model_load_internal( auto & hparams = model.hparams; { - switch (hparams.n_layer) - { - case 26: - model.type = e_model::MODEL_3B; - break; - case 32: - model.type = e_model::MODEL_7B; - break; - case 40: - model.type = e_model::MODEL_13B; - break; - case 60: - model.type = e_model::MODEL_30B; - break; - case 80: - model.type = e_model::MODEL_65B; - break; - default: - { - if (hparams.n_layer < 32) - { - model.type = e_model::MODEL_7B; - } - } - break; + switch (hparams.n_layer) { + case 26: model.type = e_model::MODEL_3B; break; + case 32: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_13B; break; + case 60: model.type = e_model::MODEL_30B; break; + case 80: model.type = e_model::MODEL_65B; break; + default: + { + if (hparams.n_layer < 32) { + model.type = e_model::MODEL_7B; + } + } break; } hparams.n_ctx = n_ctx; } - const uint32_t n_ff = ((2 * (4 * hparams.n_embd) / 3 + hparams.n_mult - 1) / hparams.n_mult) * hparams.n_mult; + const uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult; { - fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version)); - fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab); - fprintf(stderr, "%s: n_ctx = %u\n", __func__, hparams.n_ctx); - fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd); - fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult); - fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head); - fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer); - fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); + fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version)); + fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab); + fprintf(stderr, "%s: n_ctx = %u\n", __func__, hparams.n_ctx); + fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd); + fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult); + fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head); + fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer); + fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); fprintf(stderr, "%s: ftype = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype)); fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type)); } - if (file_version < LLAMA_FILE_VERSION_GGJT_V2) - { - if (hparams.ftype != LLAMA_FTYPE_ALL_F32 && - hparams.ftype != LLAMA_FTYPE_MOSTLY_F16 && - hparams.ftype != LLAMA_FTYPE_MOSTLY_Q8_0) - { + if (file_version < LLAMA_FILE_VERSION_GGJT_V2) { + if (hparams.ftype != LLAMA_FTYPE_ALL_F32 && + hparams.ftype != LLAMA_FTYPE_MOSTLY_F16 && + hparams.ftype != LLAMA_FTYPE_MOSTLY_Q8_0) { throw std::runtime_error(format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1405)")); } } - if (file_version < LLAMA_FILE_VERSION_GGJT_V3) - { + if (file_version < LLAMA_FILE_VERSION_GGJT_V3) { if (hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_1 || - hparams.ftype == LLAMA_FTYPE_MOSTLY_Q8_0) - { + hparams.ftype == LLAMA_FTYPE_MOSTLY_Q8_0) { throw std::runtime_error(format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1508)")); } } - if (vocab_only) - { + if (vocab_only) { return; } - auto &ctx = model.ctx; + auto & ctx = model.ctx; size_t ctx_size; size_t mmapped_size; ml->calc_sizes(&ctx_size, &mmapped_size); - fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); + fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/1024.0/1024.0); // create the ggml context { model.buf.resize(ctx_size); - if (use_mlock) - { + if (use_mlock) { model.mlock_buf.init(model.buf.addr); model.mlock_buf.grow_to(model.buf.size); } struct ggml_init_params params = { - /*.mem_size =*/model.buf.size, - /*.mem_buffer =*/model.buf.addr, - /*.no_alloc =*/ml->use_mmap, + /*.mem_size =*/ model.buf.size, + /*.mem_buffer =*/ model.buf.addr, + /*.no_alloc =*/ ml->use_mmap, }; model.ctx = ggml_init(params); - if (!model.ctx) - { + if (!model.ctx) { throw std::runtime_error(format("ggml_init() failed")); } } - (void)main_gpu; + (void) main_gpu; #if defined(GGML_USE_CUBLAS) fprintf(stderr, "%s: using CUDA for GPU acceleration\n", __func__); ggml_cuda_set_main_device(main_gpu); -#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU +#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU #define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT #elif defined(GGML_USE_CLBLAST) fprintf(stderr, "%s: using OpenCL for GPU acceleration\n", __func__); -#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU +#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU #define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU #else -#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CPU +#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CPU #define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_CPU #endif @@ -1255,7 +1014,7 @@ static void llama_model_load_internal( size_t vram_weights = 0; size_t vram_scratch = 0; { - const uint32_t n_embd = hparams.n_embd; + const uint32_t n_embd = hparams.n_embd; const uint32_t n_layer = hparams.n_layer; const uint32_t n_vocab = hparams.n_vocab; @@ -1267,32 +1026,27 @@ static void llama_model_load_internal( { ggml_backend backend_norm; ggml_backend backend_output; - if (n_gpu_layers > int(n_layer)) - { // NOLINT + if (n_gpu_layers > int(n_layer)) { // NOLINT // norm is not performance relevant on its own but keeping it in VRAM reduces data copying // on Windows however this is detrimental unless everything is on the GPU #ifndef _WIN32 backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; #else - backend_norm = low_vram || n_gpu_layers <= (int)n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; #endif // _WIN32 backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; - } - else - { + } else { backend_norm = GGML_BACKEND_CPU; backend_output = GGML_BACKEND_CPU; } - model.norm = ml->get_tensor("norm.weight", {n_embd}, backend_norm); + model.norm = ml->get_tensor("norm.weight", {n_embd}, backend_norm); model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output); - if (backend_norm == GGML_BACKEND_GPU) - { + if (backend_norm == GGML_BACKEND_GPU) { vram_weights += ggml_nbytes(model.norm); } - if (backend_output == GGML_BACKEND_GPU_SPLIT) - { + if (backend_output == GGML_BACKEND_GPU_SPLIT) { vram_weights += ggml_nbytes(model.output); } } @@ -1300,12 +1054,11 @@ static void llama_model_load_internal( const int i_gpu_start = n_layer - n_gpu_layers; model.layers.resize(n_layer); - for (uint32_t i = 0; i < n_layer; ++i) - { - const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT - auto &layer = model.layers[i]; + auto & layer = model.layers[i]; std::string layers_i = "layers." + std::to_string(i); @@ -1318,16 +1071,15 @@ static void llama_model_load_internal( layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend); - layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend_split); - layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", {n_ff, n_embd}, backend_split); - layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend_split); + layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend_split); + layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend_split); + layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend_split); - if (backend == GGML_BACKEND_GPU) - { + if (backend == GGML_BACKEND_GPU) { vram_weights += - ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + - ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + - ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); + ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + + ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); } } } @@ -1344,29 +1096,25 @@ static void llama_model_load_internal( mmapped_size - vram_weights + // weights in VRAM not in memory MEM_REQ_SCRATCH0().at(model.type) + MEM_REQ_SCRATCH1().at(model.type) + - MEM_REQ_EVAL().at(model.type); + MEM_REQ_EVAL().at (model.type); // this is the memory required by one llama_state const size_t mem_required_state = - scale * MEM_REQ_KV_SELF().at(model.type); + scale*MEM_REQ_KV_SELF().at(model.type); fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__, mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0); - (void)vram_scratch; - (void)n_batch; + (void) vram_scratch; + (void) n_batch; #ifdef GGML_USE_CUBLAS - if (low_vram) - { + if (low_vram) { fprintf(stderr, "%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__); ggml_cuda_set_scratch_size(0); // disable scratch - } - else - { + } else { vram_scratch = n_batch * MB; ggml_cuda_set_scratch_size(vram_scratch); - if (n_gpu_layers > 0) - { + if (n_gpu_layers > 0) { fprintf(stderr, "%s: allocating batch_size x 1 MB = %zd MB VRAM for the scratch buffer\n", __func__, vram_scratch / MB); } @@ -1376,31 +1124,22 @@ static void llama_model_load_internal( const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); fprintf(stderr, "%s: offloading %d repeating layers to GPU\n", __func__, n_gpu); - if (n_gpu_layers > (int)hparams.n_layer) - { + if (n_gpu_layers > (int) hparams.n_layer) { fprintf(stderr, "%s: offloading non-repeating layers to GPU\n", __func__); } size_t vram_kv_cache = 0; - if (n_gpu_layers > (int)hparams.n_layer + 1) - { - if (low_vram) - { + if (n_gpu_layers > (int) hparams.n_layer + 1) { + if (low_vram) { fprintf(stderr, "%s: cannot offload v cache to GPU due to low VRAM option\n", __func__); - } - else - { + } else { fprintf(stderr, "%s: offloading v cache to GPU\n", __func__); vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2; } } - if (n_gpu_layers > (int)hparams.n_layer + 2) - { - if (low_vram) - { + if (n_gpu_layers > (int) hparams.n_layer + 2) { + if (low_vram) { fprintf(stderr, "%s: cannot offload k cache to GPU due to low VRAM option\n", __func__); - } - else - { + } else { fprintf(stderr, "%s: offloading k cache to GPU\n", __func__); vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2; } @@ -1411,17 +1150,16 @@ static void llama_model_load_internal( fprintf(stderr, "%s: total VRAM used: %zu MB\n", __func__, (vram_weights + vram_scratch + vram_kv_cache + MB - 1) / MB); // round up #else - (void)n_gpu_layers; + (void) n_gpu_layers; #endif } // populate `tensors_by_name` - for (llama_load_tensor < : ml->tensors_map.tensors) - { + for (llama_load_tensor & lt : ml->tensors_map.tensors) { model.tensors_by_name.emplace_back(lt.name, lt.ggml_tensor); } - (void)tensor_split; + (void) tensor_split; #if defined(GGML_USE_CUBLAS) { ggml_cuda_set_tensor_split(tensor_split); @@ -1430,8 +1168,7 @@ static void llama_model_load_internal( ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &model.mlock_mmap : NULL); - if (progress_callback) - { + if (progress_callback) { progress_callback(1.0f, progress_callback_user_data); } @@ -1443,30 +1180,26 @@ static void llama_model_load_internal( } static bool llama_model_load( - const std::string &fname, - llama_model &model, - llama_vocab &vocab, - int n_ctx, - int n_batch, - int n_gpu_layers, - int main_gpu, - float *tensor_split, - bool low_vram, - ggml_type memory_type, - bool use_mmap, - bool use_mlock, - bool vocab_only, - llama_progress_callback progress_callback, - void *progress_callback_user_data) -{ - try - { + const std::string & fname, + llama_model & model, + llama_vocab & vocab, + int n_ctx, + int n_batch, + int n_gpu_layers, + int main_gpu, + float * tensor_split, + bool low_vram, + ggml_type memory_type, + bool use_mmap, + bool use_mlock, + bool vocab_only, + llama_progress_callback progress_callback, + void *progress_callback_user_data) { + try { llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, low_vram, memory_type, use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data); return true; - } - catch (const std::exception &err) - { + } catch (const std::exception & err) { fprintf(stderr, "error loading model: %s\n", err.what()); return false; } @@ -1502,31 +1235,31 @@ static bool llama_eval_internal( const int N = n_tokens; - const auto &model = lctx.model; - const auto &hparams = model.hparams; + const auto & model = lctx.model; + const auto & hparams = model.hparams; - const auto &kv_self = lctx.kv_self; + const auto & kv_self = lctx.kv_self; LLAMA_ASSERT(!!kv_self.ctx); - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_ctx = hparams.n_ctx; - const int n_head = hparams.n_head; - const int n_vocab = hparams.n_vocab; - const int n_rot = hparams.n_embd / hparams.n_head; + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + const int n_head = hparams.n_head; + const int n_vocab = hparams.n_vocab; + const int n_rot = hparams.n_embd/hparams.n_head; const int n_gpu_layers = model.n_gpu_layers; - auto &mem_per_token = lctx.mem_per_token; - auto &buf_compute = lctx.buf_compute; + auto & mem_per_token = lctx.mem_per_token; + auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { - /*.mem_size =*/buf_compute.size, - /*.mem_buffer =*/buf_compute.addr, - /*.no_alloc =*/false, + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.addr, + /*.no_alloc =*/ false, }; - struct ggml_context *ctx0 = ggml_init(params); + struct ggml_context * ctx0 = ggml_init(params); // for big prompts, if BLAS is enabled, it is better to use only one thread // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance @@ -1547,7 +1280,7 @@ static bool llama_eval_internal( } const int i_gpu_start = n_layer - n_gpu_layers; - (void)i_gpu_start; + (void) i_gpu_start; // offload functions set the tensor output backend to GPU // tensors are GPU-accelerated if any input or the output has been offloaded @@ -1556,35 +1289,30 @@ static bool llama_eval_internal( // in that case ggml_cuda_assign_buffers has no effect offload_func_t offload_func_nr = llama_nop; // nr = non-repeating offload_func_t offload_func_kq = llama_nop; - offload_func_t offload_func_v = llama_nop; + offload_func_t offload_func_v = llama_nop; #ifdef GGML_USE_CUBLAS - if (n_gpu_layers > n_layer) - { - offload_func_nr = ggml_cuda_assign_buffers; - } - if (n_gpu_layers > n_layer + 1) - { - offload_func_v = ggml_cuda_assign_buffers; - } - if (n_gpu_layers > n_layer + 2) - { - offload_func_kq = ggml_cuda_assign_buffers; - } + if (n_gpu_layers > n_layer) { + offload_func_nr = ggml_cuda_assign_buffers; + } + if (n_gpu_layers > n_layer + 1) { + offload_func_v = ggml_cuda_assign_buffers; + } + if (n_gpu_layers > n_layer + 2) { + offload_func_kq = ggml_cuda_assign_buffers; + } #endif // GGML_USE_CUBLAS - for (int il = 0; il < n_layer; ++il) - { + for (int il = 0; il < n_layer; ++il) { offload_func_t offload_func = llama_nop; #ifdef GGML_USE_CUBLAS - if (il >= i_gpu_start) - { + if (il >= i_gpu_start) { offload_func = ggml_cuda_assign_buffers; } #endif // GGML_USE_CUBLAS - struct ggml_tensor *inpSA = inpL; + struct ggml_tensor * inpSA = inpL; lctx.use_buf(ctx0, 0); @@ -1603,11 +1331,11 @@ static bool llama_eval_internal( // self-attention { // compute Q and K and RoPE them - struct ggml_tensor *tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); offload_func_kq(tmpk); ggml_set_name(tmpk, "tmpk"); - struct ggml_tensor *tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); offload_func_kq(tmpq); ggml_set_name(tmpq, "tmpq"); @@ -1623,21 +1351,21 @@ static bool llama_eval_internal( { // compute the transposed [N, n_embd] V matrix - struct ggml_tensor *tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); offload_func_v(tmpv); ggml_set_name(tmpv, "tmpv"); - struct ggml_tensor *Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd, N)); + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd, N)); offload_func_v(Vcur); ggml_set_name(Vcur, "Vcur"); - struct ggml_tensor *k = ggml_view_1d(ctx0, kv_self.k, N * n_embd, (ggml_element_size(kv_self.k) * n_embd) * (il * n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); offload_func_kq(k); ggml_set_name(k, "k"); - struct ggml_tensor *v = ggml_view_2d(ctx0, kv_self.v, N, n_embd, - (n_ctx)*ggml_element_size(kv_self.v), - (il * n_ctx) * ggml_element_size(kv_self.v) * n_embd + n_past * ggml_element_size(kv_self.v)); + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); offload_func_v(v); ggml_set_name(v, "v"); @@ -1646,91 +1374,91 @@ static bool llama_eval_internal( ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); } - struct ggml_tensor *Q = + struct ggml_tensor * Q = ggml_permute(ctx0, - Qcur, - 0, 2, 1, 3); + Qcur, + 0, 2, 1, 3); offload_func_kq(Q); ggml_set_name(Q, "Q"); - struct ggml_tensor *K = + struct ggml_tensor * K = ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, kv_self.k, (n_past + N) * n_embd, il * n_ctx * ggml_element_size(kv_self.k) * n_embd), - n_embd / n_head, n_head, n_past + N), - 0, 2, 1, 3); + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd), + n_embd/n_head, n_head, n_past + N), + 0, 2, 1, 3); offload_func_kq(K); ggml_set_name(K, "K"); // K * Q - struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q); + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); offload_func_kq(KQ); ggml_set_name(KQ, "KQ"); // KQ_scaled = KQ / sqrt(n_embd/n_head) - struct ggml_tensor *KQ_scale = ggml_new_f32(ctx0, 1.0f / sqrtf(float(n_embd) / n_head)); + struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)); ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)"); // KQ_scaled shape [n_past + N, N, n_head, 1] - struct ggml_tensor *KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); // KQ_masked = mask_past(KQ_scaled) - struct ggml_tensor *KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); offload_func_kq(KQ_masked); ggml_set_name(KQ_masked, "KQ_masked"); // KQ = soft_max(KQ_masked) - struct ggml_tensor *KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); // split cached V into n_head heads - struct ggml_tensor *V = + struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, - n_past + N, n_embd / n_head, n_head, - n_ctx * ggml_element_size(kv_self.v), - n_ctx * ggml_element_size(kv_self.v) * n_embd / n_head, - il * n_ctx * ggml_element_size(kv_self.v) * n_embd); + n_past + N, n_embd/n_head, n_head, + n_ctx*ggml_element_size(kv_self.v), + n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head, + il*n_ctx*ggml_element_size(kv_self.v)*n_embd); offload_func_v(V); ggml_set_name(V, "V"); #if 1 - struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); offload_func_v(KQV); ggml_set_name(KQV, "KQV"); #else // make V contiguous in memory to speed up the matmul, however we waste time on the copy // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation // is there a better way? - struct ggml_tensor *V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd / n_head, n_head)); - struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max); + struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head)); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max); #endif // KQV_merged = KQV.permute(0, 2, 1, 3) - struct ggml_tensor *KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); offload_func_v(KQV_merged); ggml_set_name(KQV_merged, "KQV_merged"); // cur = KQV_merged.contiguous().view(n_embd, N) cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); offload_func_v(cur); ggml_set_name(cur, "KQV_merged_contiguous"); // projection (no bias) cur = ggml_mul_mat(ctx0, - model.layers[il].wo, - cur); + model.layers[il].wo, + cur); offload_func(cur); ggml_set_name(cur, "result_wo"); } lctx.use_buf(ctx0, 1); - struct ggml_tensor *inpFF = ggml_add(ctx0, cur, inpSA); + struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); offload_func(inpFF); ggml_set_name(inpFF, "inpFF"); @@ -1748,15 +1476,15 @@ static bool llama_eval_internal( ggml_set_name(cur, "ffn_norm"); } - struct ggml_tensor *tmp = ggml_mul_mat(ctx0, - model.layers[il].w3, - cur); + struct ggml_tensor * tmp = ggml_mul_mat(ctx0, + model.layers[il].w3, + cur); offload_func(tmp); ggml_set_name(tmp, "result_w3"); cur = ggml_mul_mat(ctx0, - model.layers[il].w1, - cur); + model.layers[il].w1, + cur); offload_func(cur); ggml_set_name(cur, "result_w1"); @@ -1770,8 +1498,8 @@ static bool llama_eval_internal( ggml_set_name(cur, "silu_x_result_w3"); cur = ggml_mul_mat(ctx0, - model.layers[il].w2, - cur); + model.layers[il].w2, + cur); offload_func(cur); ggml_set_name(cur, "result_w2"); } @@ -1782,12 +1510,14 @@ static bool llama_eval_internal( // input for next layer inpL = cur; + } lctx.use_buf(ctx0, 0); // used at the end to optionally extract the embeddings - struct ggml_tensor *embeddings = NULL; + struct ggml_tensor * embeddings = NULL; + // norm { @@ -1803,6 +1533,7 @@ static bool llama_eval_internal( embeddings = cur; } + // lm_head cur = ggml_mul_mat(ctx0, model.output, cur); ggml_set_name(cur, "result_output"); @@ -1810,19 +1541,16 @@ static bool llama_eval_internal( lctx.use_buf(ctx0, -1); // logits -> probs - // cur = ggml_soft_max_inplace(ctx0, cur); + //cur = ggml_soft_max_inplace(ctx0, cur); // run the computation ggml_build_forward_expand(&gf, cur); #ifdef GGML_USE_METAL - if (lctx.ctx_metal && N == 1) - { + if (lctx.ctx_metal && N == 1) { ggml_metal_graph_compute(lctx.ctx_metal, &gf); - ggml_metal_get_tensor(lctx.ctx_metal, cur); - } - else - { + ggml_metal_get_tensor (lctx.ctx_metal, cur); + } else { // IMPORTANT: // Since we don't have efficient Matrix x Matrix Metal multiplication yet, we fallback to vanilla // ggml_graph_compute(). It uses Apple's Accelerate CBLAS API which takes advantage of the ANE or the AMX @@ -1833,8 +1561,7 @@ static bool llama_eval_internal( // // TODO: avoid these syncs via shared memory (ref #1696) // - if (lctx.ctx_metal) - { + if (lctx.ctx_metal) { // We need to sync the GPU KV cache with the CPU KV cache ggml_metal_get_tensor(lctx.ctx_metal, kv_self.k); ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v); @@ -1846,8 +1573,7 @@ static bool llama_eval_internal( ggml_graph_compute(ctx0, &gf); #endif - if (cgraph_fname) - { + if (cgraph_fname) { ggml_graph_export(&gf, cgraph_fname); } @@ -1858,45 +1584,40 @@ static bool llama_eval_internal( #endif // plot the computation graph in dot format (for debugging purposes) - // if (n_past%100 == 0) { + //if (n_past%100 == 0) { // ggml_graph_dump_dot(&gf, NULL, "llama.dot"); //} - // embd_w.resize(n_vocab*N); - // memcpy(embd_w.data(), ggml_get_data(cur), sizeof(float)*n_vocab*N); + //embd_w.resize(n_vocab*N); + //memcpy(embd_w.data(), ggml_get_data(cur), sizeof(float)*n_vocab*N); // update kv token count lctx.kv_self.n = n_past + N; // extract logits { - auto &logits_out = lctx.logits; + auto & logits_out = lctx.logits; - if (lctx.logits_all) - { + if (lctx.logits_all) { logits_out.resize(n_vocab * N); - memcpy(logits_out.data(), (float *)ggml_get_data(cur), sizeof(float) * n_vocab * N); - } - else - { + memcpy(logits_out.data(), (float *) ggml_get_data(cur), sizeof(float)*n_vocab*N); + } else { // return result for just the last token logits_out.resize(n_vocab); - memcpy(logits_out.data(), (float *)ggml_get_data(cur) + (n_vocab * (N - 1)), sizeof(float) * n_vocab); + memcpy(logits_out.data(), (float *) ggml_get_data(cur) + (n_vocab*(N-1)), sizeof(float)*n_vocab); } } // extract embeddings - if (!lctx.embedding.empty()) - { - auto &embedding_out = lctx.embedding; + if (!lctx.embedding.empty()) { + auto & embedding_out = lctx.embedding; embedding_out.resize(n_embd); - memcpy(embedding_out.data(), (float *)ggml_get_data(embeddings) + (n_embd * (N - 1)), sizeof(float) * n_embd); + memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd); } - if (mem_per_token == 0) - { - mem_per_token = ggml_used_mem(ctx0) / N; + if (mem_per_token == 0) { + mem_per_token = ggml_used_mem(ctx0)/N; } #if 0 @@ -1909,13 +1630,11 @@ static bool llama_eval_internal( ggml_free(ctx0); // measure the performance only for the single-token evals - if (N == 1) - { + if (N == 1) { lctx.t_eval_us += ggml_time_us() - t_start_us; lctx.n_eval++; } - else if (N > 1) - { + else if (N > 1) { lctx.t_p_eval_us += ggml_time_us() - t_start_us; lctx.n_p_eval += N; } @@ -1927,30 +1646,25 @@ static bool llama_eval_internal( // tokenizer // -static size_t utf8_len(char src) -{ - const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4}; +static size_t utf8_len(char src) { + const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; uint8_t highbits = static_cast(src) >> 4; return lookup[highbits]; } -struct llama_sp_symbol -{ +struct llama_sp_symbol { using index = int; index prev; index next; - const char *text; + const char * text; size_t n; }; static_assert(std::is_trivially_copyable::value, "llama_sp_symbol is not trivially copyable"); -struct llama_sp_bigram -{ - struct comparator - { - bool operator()(llama_sp_bigram &l, llama_sp_bigram &r) - { +struct llama_sp_bigram { + struct comparator { + bool operator()(llama_sp_bigram & l, llama_sp_bigram & r) { return (l.score < r.score) || (l.score == r.score && l.left > r.left); } }; @@ -1964,17 +1678,14 @@ struct llama_sp_bigram // original implementation: // https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 -struct llama_tokenizer -{ - llama_tokenizer(const llama_vocab &vocab) : vocab_(vocab) {} +struct llama_tokenizer { + llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {} - void tokenize(const std::string &text, std::vector &output) - { + void tokenize(const std::string & text, std::vector & output) { // split string into utf8 chars int index = 0; size_t offs = 0; - while (offs < text.size()) - { + while (offs < text.size()) { llama_sp_symbol sym; size_t char_len = std::min(text.size() - offs, utf8_len(text[offs])); sym.text = text.c_str() + offs; @@ -1987,24 +1698,21 @@ struct llama_tokenizer } // seed the work queue with all possible 2-character tokens. - for (size_t i = 1; i < symbols_.size(); ++i) - { + for (size_t i = 1; i < symbols_.size(); ++i) { try_add_bigram(i - 1, i); } // keep substituting the highest frequency pairs for as long as we can. - while (!work_queue_.empty()) - { + while (!work_queue_.empty()) { auto bigram = work_queue_.top(); work_queue_.pop(); - auto &left_sym = symbols_[bigram.left]; - auto &right_sym = symbols_[bigram.right]; + auto & left_sym = symbols_[bigram.left]; + auto & right_sym = symbols_[bigram.right]; // if one of the symbols already got merged, skip it. if (left_sym.n == 0 || right_sym.n == 0 || - left_sym.n + right_sym.n != bigram.size) - { + left_sym.n + right_sym.n != bigram.size) { continue; } @@ -2012,12 +1720,11 @@ struct llama_tokenizer left_sym.n += right_sym.n; right_sym.n = 0; - // printf("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size); + //printf("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size); // remove the right sym from the chain left_sym.next = right_sym.next; - if (right_sym.next >= 0) - { + if (right_sym.next >= 0) { symbols_[right_sym.next].prev = bigram.left; } @@ -2026,45 +1733,36 @@ struct llama_tokenizer try_add_bigram(bigram.left, left_sym.next); } - for (int i = 0; i != -1; i = symbols_[i].next) - { - auto &symbol = symbols_[i]; + for (int i = 0; i != -1; i = symbols_[i].next) { + auto & symbol = symbols_[i]; auto token = vocab_.token_to_id.find(std::string(symbol.text, symbol.n)); - if (token == vocab_.token_to_id.end()) - { + if (token == vocab_.token_to_id.end()) { // output any symbols that did not form tokens as bytes. - for (int j = 0; j < (int)symbol.n; ++j) - { + for (int j = 0; j < (int) symbol.n; ++j) { llama_vocab::id token_id = static_cast(symbol.text[j]) + 3; output.push_back(token_id); } - } - else - { + } else { output.push_back((*token).second); } } } private: - void try_add_bigram(int left, int right) - { - if (left == -1 || right == -1) - { + void try_add_bigram(int left, int right) { + if (left == -1 || right == -1) { return; } const std::string text = std::string(symbols_[left].text, symbols_[left].n + symbols_[right].n); auto token = vocab_.token_to_id.find(text); - if (token == vocab_.token_to_id.end()) - { + if (token == vocab_.token_to_id.end()) { return; } - if (static_cast((*token).second) >= vocab_.id_to_token.size()) - { + if (static_cast((*token).second) >= vocab_.id_to_token.size()) { return; } @@ -2078,23 +1776,20 @@ struct llama_tokenizer work_queue_.push(bigram); } - const llama_vocab &vocab_; + const llama_vocab & vocab_; std::vector symbols_; llama_sp_bigram::queue work_queue_; }; -static std::vector llama_tokenize(const llama_vocab &vocab, const std::string &text, bool bos) -{ +static std::vector llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) { llama_tokenizer tokenizer(vocab); std::vector output; - if (text.empty()) - { + if (text.empty()) { return output; } - if (bos) - { + if (bos) { output.push_back(llama_token_bos()); } @@ -2106,75 +1801,62 @@ static std::vector llama_tokenize(const llama_vocab &vocab, con // sampling // -void llama_sample_softmax(struct llama_context *ctx, llama_token_data_array *candidates) -{ +void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { assert(candidates->size > 0); const int64_t t_start_sample_us = ggml_time_us(); // Sort the logits in descending order - if (!candidates->sorted) - { - std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data &a, const llama_token_data &b) - { return a.logit > b.logit; }); + if (!candidates->sorted) { + std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); candidates->sorted = true; } float max_l = candidates->data[0].logit; float cum_sum = 0.0f; - for (size_t i = 0; i < candidates->size; ++i) - { + for (size_t i = 0; i < candidates->size; ++i) { float p = expf(candidates->data[i].logit - max_l); candidates->data[i].p = p; cum_sum += p; } - for (size_t i = 0; i < candidates->size; ++i) - { + for (size_t i = 0; i < candidates->size; ++i) { candidates->data[i].p /= cum_sum; } - if (ctx) - { + if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } } -void llama_sample_top_k(struct llama_context *ctx, llama_token_data_array *candidates, int k, size_t min_keep) -{ +void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep) { const int64_t t_start_sample_us = ggml_time_us(); - k = std::max(k, (int)min_keep); - k = std::min(k, (int)candidates->size); + k = std::max(k, (int) min_keep); + k = std::min(k, (int) candidates->size); // Sort scores in descending order - if (!candidates->sorted) - { - auto comp = [](const llama_token_data &a, const llama_token_data &b) - { + if (!candidates->sorted) { + auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; }; - if (k == (int)candidates->size) - { + if (k == (int) candidates->size) { std::sort(candidates->data, candidates->data + candidates->size, comp); - } - else - { + } else { std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); } candidates->sorted = true; } candidates->size = k; - if (ctx) - { + if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } } -void llama_sample_top_p(struct llama_context *ctx, llama_token_data_array *candidates, float p, size_t min_keep) -{ - if (p >= 1.0f) - { +void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + if (p >= 1.0f) { return; } @@ -2186,14 +1868,12 @@ void llama_sample_top_p(struct llama_context *ctx, llama_token_data_array *candi float cum_sum = 0.0f; size_t last_idx = candidates->size; - for (size_t i = 0; i < candidates->size; ++i) - { + for (size_t i = 0; i < candidates->size; ++i) { cum_sum += candidates->data[i].p; // Check if the running sum is at least p or if we have kept at least min_keep tokens // we set the last index to i+1 to indicate that the current iterate should be included in the set - if (cum_sum >= p && i + 1 >= min_keep) - { + if (cum_sum >= p && i + 1 >= min_keep) { last_idx = i + 1; break; } @@ -2202,16 +1882,13 @@ void llama_sample_top_p(struct llama_context *ctx, llama_token_data_array *candi // Resize the output vector to keep only the top-p tokens candidates->size = last_idx; - if (ctx) - { + if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } } -void llama_sample_tail_free(struct llama_context *ctx, llama_token_data_array *candidates, float z, size_t min_keep) -{ - if (z >= 1.0f || candidates->size <= 2) - { +void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { + if (z >= 1.0f || candidates->size <= 2) { return; } @@ -2223,37 +1900,31 @@ void llama_sample_tail_free(struct llama_context *ctx, llama_token_data_array *c std::vector first_derivatives(candidates->size - 1); std::vector second_derivatives(candidates->size - 2); - for (size_t i = 0; i < first_derivatives.size(); ++i) - { + for (size_t i = 0; i < first_derivatives.size(); ++i) { first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p; } - for (size_t i = 0; i < second_derivatives.size(); ++i) - { + for (size_t i = 0; i < second_derivatives.size(); ++i) { second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1]; } // Calculate absolute value of second derivatives - for (size_t i = 0; i < second_derivatives.size(); ++i) - { + for (size_t i = 0; i < second_derivatives.size(); ++i) { second_derivatives[i] = abs(second_derivatives[i]); } // Normalize the second derivatives float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f); - for (float &value : second_derivatives) - { + for (float & value : second_derivatives) { value /= second_derivatives_sum; } float cum_sum = 0.0f; size_t last_idx = candidates->size; - for (size_t i = 0; i < second_derivatives.size(); ++i) - { + for (size_t i = 0; i < second_derivatives.size(); ++i) { cum_sum += second_derivatives[i]; // Check if the running sum is greater than z or if we have kept at least min_keep tokens - if (cum_sum > z && i >= min_keep) - { + if (cum_sum > z && i >= min_keep) { last_idx = i; break; } @@ -2262,18 +1933,16 @@ void llama_sample_tail_free(struct llama_context *ctx, llama_token_data_array *c // Resize the output vector to keep only the tokens above the tail location candidates->size = last_idx; - if (ctx) - { + if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } } -void llama_sample_typical(struct llama_context *ctx, llama_token_data_array *candidates, float p, size_t min_keep) -{ + +void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr - if (p >= 1.0f) - { + if (p >= 1.0f) { return; } @@ -2283,15 +1952,13 @@ void llama_sample_typical(struct llama_context *ctx, llama_token_data_array *can llama_sample_softmax(nullptr, candidates); float entropy = 0.0f; - for (size_t i = 0; i < candidates->size; ++i) - { + for (size_t i = 0; i < candidates->size; ++i) { entropy += -candidates->data[i].p * logf(candidates->data[i].p); } // Compute the absolute difference between negative log probability and entropy for each candidate std::vector shifted_scores; - for (size_t i = 0; i < candidates->size; ++i) - { + for (size_t i = 0; i < candidates->size; ++i) { float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy); shifted_scores.push_back(shifted_score); } @@ -2300,21 +1967,20 @@ void llama_sample_typical(struct llama_context *ctx, llama_token_data_array *can std::vector indices(candidates->size); std::iota(indices.begin(), indices.end(), 0); - std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) - { return shifted_scores[a] < shifted_scores[b]; }); + std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { + return shifted_scores[a] < shifted_scores[b]; + }); // Compute the cumulative probabilities float cum_sum = 0.0f; size_t last_idx = indices.size(); - for (size_t i = 0; i < indices.size(); ++i) - { + for (size_t i = 0; i < indices.size(); ++i) { size_t idx = indices[i]; cum_sum += candidates->data[idx].p; // Check if the running sum is greater than typical or if we have kept at least min_keep tokens - if (cum_sum > p && i >= min_keep - 1) - { + if (cum_sum > p && i >= min_keep - 1) { last_idx = i + 1; break; } @@ -2322,8 +1988,7 @@ void llama_sample_typical(struct llama_context *ctx, llama_token_data_array *can // Resize the output vector to keep only the locally typical tokens std::vector new_candidates; - for (size_t i = 0; i < last_idx; ++i) - { + for (size_t i = 0; i < last_idx; ++i) { size_t idx = indices[i]; new_candidates.push_back(candidates->data[idx]); } @@ -2332,68 +1997,54 @@ void llama_sample_typical(struct llama_context *ctx, llama_token_data_array *can std::copy(new_candidates.begin(), new_candidates.end(), candidates->data); candidates->size = new_candidates.size(); - if (ctx) - { + if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } } -void llama_sample_temperature(struct llama_context *ctx, llama_token_data_array *candidates_p, float temp) -{ +void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { const int64_t t_start_sample_us = ggml_time_us(); - for (size_t i = 0; i < candidates_p->size; ++i) - { + for (size_t i = 0; i < candidates_p->size; ++i) { candidates_p->data[i].logit /= temp; } - if (ctx) - { + if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } } -void llama_sample_repetition_penalty(struct llama_context *ctx, llama_token_data_array *candidates, const llama_token *last_tokens, size_t last_tokens_size, float penalty) -{ - if (last_tokens_size == 0 || penalty == 1.0f) - { +void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty) { + if (last_tokens_size == 0 || penalty == 1.0f) { return; } const int64_t t_start_sample_us = ggml_time_us(); - for (size_t i = 0; i < candidates->size; ++i) - { - const auto *token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id); - if (token_iter == last_tokens + last_tokens_size) - { + for (size_t i = 0; i < candidates->size; ++i) { + const auto * token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id); + if (token_iter == last_tokens + last_tokens_size) { continue; } // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. // This is common fix for this problem, which is to multiply by the penalty instead of dividing. - if (candidates->data[i].logit <= 0) - { + if (candidates->data[i].logit <= 0) { candidates->data[i].logit *= penalty; - } - else - { + } else { candidates->data[i].logit /= penalty; } } candidates->sorted = false; - if (ctx) - { + if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } } -void llama_sample_frequency_and_presence_penalties(struct llama_context *ctx, llama_token_data_array *candidates, const llama_token *last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) -{ - if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) - { +void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) { + if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) { return; } @@ -2401,17 +2052,14 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context *ctx, ll // Create a frequency map to count occurrences of each token in last_tokens std::unordered_map token_count; - for (size_t i = 0; i < last_tokens_size; ++i) - { + for (size_t i = 0; i < last_tokens_size; ++i) { token_count[last_tokens_p[i]]++; } // Apply frequency and presence penalties to the candidates - for (size_t i = 0; i < candidates->size; ++i) - { + for (size_t i = 0; i < candidates->size; ++i) { auto token_iter = token_count.find(candidates->data[i].id); - if (token_iter == token_count.end()) - { + if (token_iter == token_count.end()) { continue; } @@ -2421,14 +2069,13 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context *ctx, ll candidates->sorted = false; - if (ctx) - { + if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } } -llama_token llama_sample_token_mirostat(struct llama_context *ctx, llama_token_data_array *candidates, float tau, float eta, int m, float *mu) -{ + +llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) { assert(ctx); auto N = float(llama_n_vocab(ctx)); int64_t t_start_sample_us; @@ -2440,8 +2087,7 @@ llama_token llama_sample_token_mirostat(struct llama_context *ctx, llama_token_d float s_hat = 0.0; float sum_ti_bi = 0.0; float sum_ti_sq = 0.0; - for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) - { + for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) { float t_i = logf(float(i + 2) / float(i + 1)); float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p); sum_ti_bi += t_i * b_i; @@ -2455,32 +2101,30 @@ llama_token llama_sample_token_mirostat(struct llama_context *ctx, llama_token_d // Sample the next word X using top-k sampling llama_sample_top_k(nullptr, candidates, int(k), 1); - if (ctx) - { + if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } llama_token X = llama_sample_token(ctx, candidates); t_start_sample_us = ggml_time_us(); // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data &candidate) - { return candidate.id == X; })); + size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return candidate.id == X; + })); float observed_surprise = -log2f(candidates->data[X_idx].p); float e = observed_surprise - tau; // Update mu using the learning rate and error *mu = *mu - eta * e; - if (ctx) - { + if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; ctx->n_sample++; } return X; } -llama_token llama_sample_token_mirostat_v2(struct llama_context *ctx, llama_token_data_array *candidates, float tau, float eta, float *mu) -{ +llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { assert(ctx); int64_t t_start_sample_us; t_start_sample_us = ggml_time_us(); @@ -2488,11 +2132,11 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context *ctx, llama_toke llama_sample_softmax(ctx, candidates); // Truncate the words with surprise values greater than mu - candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data &candidate) - { return -log2f(candidate.p) > *mu; })); + candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return -log2f(candidate.p) > *mu; + })); - if (candidates->size == 0) - { + if (candidates->size == 0) { candidates->size = 1; } @@ -2500,61 +2144,57 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context *ctx, llama_toke llama_sample_softmax(ctx, candidates); // Sample the next word X from the remaining words - if (ctx) - { + if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } llama_token X = llama_sample_token(ctx, candidates); t_start_sample_us = ggml_time_us(); // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data &candidate) - { return candidate.id == X; })); + size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return candidate.id == X; + })); float observed_surprise = -log2f(candidates->data[X_idx].p); float e = observed_surprise - tau; // Update mu using the learning rate and error *mu = *mu - eta * e; - if (ctx) - { + if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } return X; } -llama_token llama_sample_token_greedy(struct llama_context *ctx, llama_token_data_array *candidates) -{ +llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) { const int64_t t_start_sample_us = ggml_time_us(); // Find max element - auto *max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data &a, const llama_token_data &b) - { return a.logit < b.logit; }); + auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit < b.logit; + }); llama_token result = max_iter->id; - if (ctx) - { + if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; ctx->n_sample++; } return result; } -llama_token llama_sample_token(struct llama_context *ctx, llama_token_data_array *candidates) -{ +llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { assert(ctx); const int64_t t_start_sample_us = ggml_time_us(); llama_sample_softmax(nullptr, candidates); std::vector probs; probs.reserve(candidates->size); - for (size_t i = 0; i < candidates->size; ++i) - { + for (size_t i = 0; i < candidates->size; ++i) { probs.push_back(candidates->data[i].p); } std::discrete_distribution<> dist(probs.begin(), probs.end()); - auto &rng = ctx->rng; + auto & rng = ctx->rng; int idx = dist(rng); llama_token result = candidates->data[idx].id; @@ -2568,40 +2208,28 @@ llama_token llama_sample_token(struct llama_context *ctx, llama_token_data_array // quantization // -static void llama_convert_tensor_internal(const llama_load_tensor &tensor, llama_buffer &output, const int nelements, const int nthread) -{ - if (output.size < nelements * sizeof(float)) - { +static void llama_convert_tensor_internal(const llama_load_tensor & tensor, llama_buffer & output, const int nelements, const int nthread) { + if (output.size < nelements * sizeof(float)) { output.resize(nelements * sizeof(float)); } - float *f32_output = (float *)output.addr; + float * f32_output = (float *) output.addr; quantize_fns_t qtype; - if (ggml_is_quantized(tensor.type)) - { + if (ggml_is_quantized(tensor.type)) { qtype = ggml_internal_get_quantize_fn(tensor.type); - if (qtype.dequantize_row_q == NULL) - { + if (qtype.dequantize_row_q == NULL) { throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor.type))); } - } - else if (tensor.type != GGML_TYPE_F16) - { + } else if (tensor.type != GGML_TYPE_F16) { throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor.type))); } - if (nthread < 2) - { - if (tensor.type == GGML_TYPE_F16) - { + if (nthread < 2) { + if (tensor.type == GGML_TYPE_F16) { ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor.data, f32_output, nelements); - } - else if (ggml_is_quantized(tensor.type)) - { + } else if (ggml_is_quantized(tensor.type)) { qtype.dequantize_row_q(tensor.data, f32_output, nelements); - } - else - { + } else { LLAMA_ASSERT(false); // unreachable } return; @@ -2616,20 +2244,15 @@ static void llama_convert_tensor_internal(const llama_load_tensor &tensor, llama auto spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count std::vector workers; - for (auto tnum = 0, in_buff_offs = 0, out_buff_offs = 0; tnum < nthread; tnum++) - { + for (auto tnum = 0, in_buff_offs = 0, out_buff_offs = 0; tnum < nthread; tnum++) { auto thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread - auto thr_elems = thr_blocks * block_size; // number of elements for this thread - auto thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread + auto thr_elems = thr_blocks * block_size; // number of elements for this thread + auto thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread - auto compute = [qtype](ggml_type typ, uint8_t *inbuf, float *outbuf, int nels) - { - if (typ == GGML_TYPE_F16) - { + auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) { + if (typ == GGML_TYPE_F16) { ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels); - } - else - { + } else { qtype.dequantize_row_q(inbuf, outbuf, nels); } }; @@ -2637,70 +2260,42 @@ static void llama_convert_tensor_internal(const llama_load_tensor &tensor, llama in_buff_offs += thr_block_bytes; out_buff_offs += thr_elems; } - for (auto &worker : workers) - { + for (auto & worker : workers) { worker.join(); } + } -static void llama_model_quantize_internal(const std::string &fname_inp, const std::string &fname_out, const llama_model_quantize_params *params) -{ +static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { ggml_type quantized_type; llama_ftype ftype = params->ftype; int nthread = params->nthread; - switch (params->ftype) - { - case LLAMA_FTYPE_MOSTLY_Q4_0: - quantized_type = GGML_TYPE_Q4_0; - break; - case LLAMA_FTYPE_MOSTLY_Q4_1: - quantized_type = GGML_TYPE_Q4_1; - break; - case LLAMA_FTYPE_MOSTLY_Q5_0: - quantized_type = GGML_TYPE_Q5_0; - break; - case LLAMA_FTYPE_MOSTLY_Q5_1: - quantized_type = GGML_TYPE_Q5_1; - break; - case LLAMA_FTYPE_MOSTLY_Q8_0: - quantized_type = GGML_TYPE_Q8_0; - break; - case LLAMA_FTYPE_MOSTLY_F16: - quantized_type = GGML_TYPE_F16; - break; - case LLAMA_FTYPE_ALL_F32: - quantized_type = GGML_TYPE_F32; - break; + switch (params->ftype) { + case LLAMA_FTYPE_MOSTLY_Q4_0: quantized_type = GGML_TYPE_Q4_0; break; + case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break; + case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break; + case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break; + case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break; + case LLAMA_FTYPE_MOSTLY_F16: quantized_type = GGML_TYPE_F16; break; + case LLAMA_FTYPE_ALL_F32: quantized_type = GGML_TYPE_F32; break; #ifdef GGML_USE_K_QUANTS - // K-quants - case LLAMA_FTYPE_MOSTLY_Q2_K: - quantized_type = GGML_TYPE_Q2_K; - break; - case LLAMA_FTYPE_MOSTLY_Q3_K_S: - case LLAMA_FTYPE_MOSTLY_Q3_K_M: - case LLAMA_FTYPE_MOSTLY_Q3_K_L: - quantized_type = GGML_TYPE_Q3_K; - break; - case LLAMA_FTYPE_MOSTLY_Q4_K_S: - case LLAMA_FTYPE_MOSTLY_Q4_K_M: - quantized_type = GGML_TYPE_Q4_K; - break; - case LLAMA_FTYPE_MOSTLY_Q5_K_S: - case LLAMA_FTYPE_MOSTLY_Q5_K_M: - quantized_type = GGML_TYPE_Q5_K; - break; - case LLAMA_FTYPE_MOSTLY_Q6_K: - quantized_type = GGML_TYPE_Q6_K; - break; + // K-quants + case LLAMA_FTYPE_MOSTLY_Q2_K: quantized_type = GGML_TYPE_Q2_K; break; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: + case LLAMA_FTYPE_MOSTLY_Q3_K_M: + case LLAMA_FTYPE_MOSTLY_Q3_K_L: quantized_type = GGML_TYPE_Q3_K; break; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: + case LLAMA_FTYPE_MOSTLY_Q4_K_M: quantized_type = GGML_TYPE_Q4_K; break; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: + case LLAMA_FTYPE_MOSTLY_Q5_K_M: quantized_type = GGML_TYPE_Q5_K; break; + case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break; #endif - default: - throw std::runtime_error(format("invalid output file type %d\n", ftype)); + default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); } - if (nthread <= 0) - { + if (nthread <= 0) { nthread = std::thread::hardware_concurrency(); } @@ -2708,16 +2303,13 @@ static void llama_model_quantize_internal(const std::string &fname_inp, const st llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loader.get(), params->ftype); #ifdef GGML_USE_K_QUANTS - int n_attention_wv = 0; + int n_attention_wv = 0; int n_feed_forward_w2 = 0; - for (auto &tensor : model_loader->tensors_map.tensors) - { - if (tensor.name.find("attention.wv.weight") != std::string::npos) - { + for (auto& tensor : model_loader->tensors_map.tensors) { + if (tensor.name.find("attention.wv.weight") != std::string::npos) { ++n_attention_wv; } - else if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) - { + else if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) { ++n_feed_forward_w2; } } @@ -2738,8 +2330,7 @@ static void llama_model_quantize_internal(const std::string &fname_inp, const st }; size_t idx = 0; - for (llama_load_tensor &tensor : model_loader->tensors_map.tensors) - { + for (llama_load_tensor & tensor : model_loader->tensors_map.tensors) { llama_buffer read_data; read_data.resize(tensor.size); tensor.data = read_data.addr; @@ -2759,91 +2350,67 @@ static void llama_model_quantize_internal(const std::string &fname_inp, const st quantize &= quantized_type != tensor.type; enum ggml_type new_type; - void *new_data; + void * new_data; size_t new_size; llama_buffer work; - if (!quantize) - { + if (!quantize) { new_type = tensor.type; new_data = tensor.data; new_size = tensor.size; - printf("size = %8.3f MB\n", tensor.size / 1024.0 / 1024.0); - } - else - { + printf("size = %8.3f MB\n", tensor.size/1024.0/1024.0); + } else { new_type = quantized_type; #ifdef GGML_USE_K_QUANTS if (quantized_type == GGML_TYPE_Q2_K || quantized_type == GGML_TYPE_Q3_K || quantized_type == GGML_TYPE_Q4_K || - quantized_type == GGML_TYPE_Q5_K || quantized_type == GGML_TYPE_Q6_K) - { + quantized_type == GGML_TYPE_Q5_K || quantized_type == GGML_TYPE_Q6_K) { int nx = tensor.ne.at(0); int ny = tensor.ne.at(1); - if (nx % QK_K != 0 || ny % QK_K != 0) - { - fprintf(stderr, "\n\n========================= Tensor sizes %d x %d are not divisible by %d\n", nx, ny, QK_K); + if (nx % QK_K != 0 || ny % QK_K != 0) { + fprintf(stderr, "\n\n========================= Tensor sizes %d x %d are not divisible by %d\n",nx,ny,QK_K); fprintf(stderr, "This is required to be able to use k-quants for now!\n"); fprintf(stderr, "========================================================================================\n\n"); throw std::runtime_error("Unsupported tensor size encountered\n"); } } - if (tensor.name == "output.weight") - { + if (tensor.name == "output.weight") { int nx = tensor.ne.at(0); int ny = tensor.ne.at(1); - if (nx % QK_K == 0 && ny % QK_K == 0) - { + if (nx % QK_K == 0 && ny % QK_K == 0) { new_type = GGML_TYPE_Q6_K; } - } - else if (tensor.name.find("attention.wv.weight") != std::string::npos) - { - if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) - new_type = GGML_TYPE_Q4_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) - new_type = GGML_TYPE_Q5_K; + } else if (tensor.name.find("attention.wv.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && use_more_bits(i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K; else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) && (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K; ++i_attention_wv; - } - else if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) - { - if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) - new_type = GGML_TYPE_Q4_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) - new_type = GGML_TYPE_Q5_K; + } else if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K; //else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < n_feed_forward_w2/8) new_type = GGML_TYPE_Q6_K; ++i_feed_forward_w2; - } - else if (tensor.name.find("attention.wo.weight") != std::string::npos) - { - if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) - new_type = GGML_TYPE_Q4_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) - new_type = GGML_TYPE_Q5_K; + } else if (tensor.name.find("attention.wo.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; } #endif - float *f32_data; + float * f32_data; size_t nelements = tensor.ne.at(0) * tensor.ne.at(1); llama_buffer f32_conv_buf; - if (tensor.type == GGML_TYPE_F32) - { - f32_data = (float *)tensor.data; - } - else if (ggml_is_quantized(tensor.type) && !params->allow_requantize) - { + if (tensor.type == GGML_TYPE_F32) { + f32_data = (float *) tensor.data; + } else if (ggml_is_quantized(tensor.type) && !params->allow_requantize) { throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor.type))); - } - else - { + } else { llama_convert_tensor_internal(tensor, f32_conv_buf, nelements, nthread); - f32_data = (float *)f32_conv_buf.addr; + f32_data = (float *) f32_conv_buf.addr; } printf("quantizing .. "); @@ -2854,31 +2421,22 @@ static void llama_model_quantize_internal(const std::string &fname_inp, const st std::vector hist_cur(1 << 4, 0); int chunk_size = 32 * 512; - const int nchunk = (nelements + chunk_size - 1) / chunk_size; + const int nchunk = (nelements + chunk_size - 1)/chunk_size; const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1; - if (nthread_use < 2) - { + if (nthread_use < 2) { new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nelements, hist_cur.data()); - } - else - { + } else { size_t counter = 0; new_size = 0; - auto compute = [&mutex, &counter, &hist_cur, &new_size, new_type, f32_data, new_data, nelements, chunk_size]() - { + auto compute = [&mutex, &counter, &hist_cur, &new_size, new_type, f32_data, new_data, nelements, chunk_size] () { std::vector local_hist; size_t local_size = 0; - while (true) - { + while (true) { std::unique_lock lock(mutex); - size_t first = counter; - counter += chunk_size; - if (first >= nelements) - { - if (!local_hist.empty()) - { - for (int j = 0; j < int(local_hist.size()); ++j) - { + size_t first = counter; counter += chunk_size; + if (first >= nelements) { + if (!local_hist.empty()) { + for (int j=0; j %8.2f MB | hist: ", tensor.size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); + printf("size = %8.2f MB -> %8.2f MB | hist: ", tensor.size/1024.0/1024.0, new_size/1024.0/1024.0); int64_t tot_count = 0; - for (size_t i = 0; i < hist_cur.size(); i++) - { + for (size_t i = 0; i < hist_cur.size(); i++) { hist_all[i] += hist_cur[i]; tot_count += hist_cur[i]; } - if (tot_count > 0) - { - for (size_t i = 0; i < hist_cur.size(); i++) - { + if (tot_count > 0) { + for (size_t i = 0; i < hist_cur.size(); i++) { printf("%5.3f ", hist_cur[i] / float(nelements)); } } @@ -2931,21 +2482,18 @@ static void llama_model_quantize_internal(const std::string &fname_inp, const st file_saver.write_tensor(tensor, new_type, new_data, new_size); } - printf("%s: model size = %8.2f MB\n", __func__, total_size_org / 1024.0 / 1024.0); - printf("%s: quant size = %8.2f MB\n", __func__, total_size_new / 1024.0 / 1024.0); + printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); + printf("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0); { int64_t sum_all = 0; - for (size_t i = 0; i < hist_all.size(); i++) - { + for (size_t i = 0; i < hist_all.size(); i++) { sum_all += hist_all[i]; } - if (sum_all > 0) - { + if (sum_all > 0) { printf("%s: hist: ", __func__); - for (size_t i = 0; i < hist_all.size(); i++) - { + for (size_t i = 0; i < hist_all.size(); i++) { printf("%5.3f ", hist_all[i] / float(sum_all)); } printf("\n"); @@ -2959,20 +2507,18 @@ static void llama_model_quantize_internal(const std::string &fname_inp, const st // interface implementation // -struct llama_model *llama_load_model_from_file( - const char *path_model, - struct llama_context_params params) -{ +struct llama_model * llama_load_model_from_file( + const char * path_model, + struct llama_context_params params) { ggml_time_init(); - llama_model *model = new llama_model; + llama_model * model = new llama_model; ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gpu_layers, - params.main_gpu, params.tensor_split, params.low_vram, memory_type, params.use_mmap, params.use_mlock, - params.vocab_only, params.progress_callback, params.progress_callback_user_data)) - { + params.main_gpu, params.tensor_split, params.low_vram, memory_type, params.use_mmap, params.use_mlock, + params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { delete model; fprintf(stderr, "%s: failed to load model\n", __func__); return nullptr; @@ -2981,42 +2527,35 @@ struct llama_model *llama_load_model_from_file( return model; } -void llama_free_model(struct llama_model *model) -{ +void llama_free_model(struct llama_model * model) { delete model; } -struct llama_context *llama_new_context_with_model( - struct llama_model *model, - struct llama_context_params params) -{ +struct llama_context * llama_new_context_with_model( + struct llama_model * model, + struct llama_context_params params) { - if (!model) - { + if (!model) { return nullptr; } - llama_context *ctx = new llama_context(*model, model->vocab); + llama_context * ctx = new llama_context(*model, model->vocab); if (params.seed == LLAMA_DEFAULT_SEED) { params.seed = time(NULL); } unsigned cur_percentage = 0; - if (params.progress_callback == NULL) - { + if (params.progress_callback == NULL) { params.progress_callback_user_data = &cur_percentage; - params.progress_callback = [](float progress, void *ctx) - { - unsigned *cur_percentage_p = (unsigned *)ctx; - unsigned percentage = (unsigned)(100 * progress); - while (percentage > *cur_percentage_p) - { + params.progress_callback = [](float progress, void * ctx) { + unsigned * cur_percentage_p = (unsigned *) ctx; + unsigned percentage = (unsigned) (100 * progress); + while (percentage > *cur_percentage_p) { *cur_percentage_p = percentage; fprintf(stderr, "."); fflush(stderr); - if (percentage >= 100) - { + if (percentage >= 100) { fprintf(stderr, "\n"); } } @@ -3029,10 +2568,8 @@ struct llama_context *llama_new_context_with_model( ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; // reserve memory for context buffers - if (!params.vocab_only) - { - if (!kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, ctx->model.hparams.n_ctx, params.n_gpu_layers)) - { + if (!params.vocab_only) { + if (!kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, ctx->model.hparams.n_ctx, params.n_gpu_layers)) { fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -3043,20 +2580,16 @@ struct llama_context *llama_new_context_with_model( fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } - const auto &hparams = ctx->model.hparams; + const auto & hparams = ctx->model.hparams; // resized during inference - if (params.logits_all) - { - ctx->logits.reserve(hparams.n_ctx * hparams.n_vocab); - } - else - { + if (params.logits_all) { + ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab); + } else { ctx->logits.reserve(hparams.n_vocab); } - if (params.embedding) - { + if (params.embedding){ ctx->embedding.resize(hparams.n_embd); } @@ -3067,41 +2600,36 @@ struct llama_context *llama_new_context_with_model( } #ifdef GGML_USE_METAL - if (params.n_gpu_layers > 0) - { + if (params.n_gpu_layers > 0) { // this allocates all Metal resources and memory buffers ctx->ctx_metal = ggml_metal_init(); - void *data_ptr = NULL; + void * data_ptr = NULL; size_t data_size = 0; - if (params.use_mmap) - { - data_ptr = ctx->model.mapping->addr; + if (params.use_mmap) { + data_ptr = ctx->model.mapping->addr; data_size = ctx->model.mapping->size; - } - else - { - data_ptr = ggml_get_mem_buffer(ctx->model.ctx); - data_size = ggml_get_mem_size(ctx->model.ctx); + } else { + data_ptr = ggml_get_mem_buffer(ctx->model.ctx); + data_size = ggml_get_mem_size (ctx->model.ctx); } const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx); - printf("%s: max tensor size = %8.2f MB\n", __func__, max_size / 1024.0 / 1024.0); + printf("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0); -#define LLAMA_METAL_CHECK_BUF(result) \ - if (!(result)) \ - { \ - fprintf(stderr, "%s: failed to add buffer\n", __func__); \ - llama_free(ctx); \ - return NULL; \ +#define LLAMA_METAL_CHECK_BUF(result) \ + if (!(result)) { \ + fprintf(stderr, "%s: failed to add buffer\n", __func__); \ + llama_free(ctx); \ + return NULL; \ } LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size)); LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size, 0)); - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.addr, ctx->kv_self.buf.size, 0)); + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.addr, ctx->kv_self.buf.size, 0)); LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size, 0)); LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size, 0)); @@ -3112,56 +2640,46 @@ struct llama_context *llama_new_context_with_model( return ctx; } -struct llama_context *llama_init_from_file( - const char *path_model, - struct llama_context_params params) -{ +struct llama_context * llama_init_from_file( + const char * path_model, + struct llama_context_params params) { - struct llama_model *model = llama_load_model_from_file(path_model, params); - if (!model) - { + struct llama_model * model = llama_load_model_from_file(path_model, params); + if (!model) { return nullptr; } - struct llama_context *ctx = llama_new_context_with_model(model, params); + struct llama_context * ctx = llama_new_context_with_model(model, params); ctx->model_owner = true; return ctx; } -void llama_free(struct llama_context *ctx) -{ - if (ctx->model_owner) - { +void llama_free(struct llama_context * ctx) { + if (ctx->model_owner) { delete &ctx->model; } delete ctx; } int llama_model_quantize( - const char *fname_inp, - const char *fname_out, - const llama_model_quantize_params *params) -{ - try - { + const char * fname_inp, + const char * fname_out, + const llama_model_quantize_params *params) { + try { llama_model_quantize_internal(fname_inp, fname_out, params); return 0; - } - catch (const std::exception &err) - { + } catch (const std::exception & err) { fprintf(stderr, "%s: failed to quantize: %s\n", __func__, err.what()); return 1; } } -int llama_apply_lora_from_file_internal(const struct llama_model &model, const char *path_lora, const char *path_base_model, int n_threads) -{ +int llama_apply_lora_from_file_internal(const struct llama_model & model, const char * path_lora, const char * path_base_model, int n_threads) { fprintf(stderr, "%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora); const int64_t t_start_lora_us = ggml_time_us(); auto fin = std::ifstream(path_lora, std::ios::binary); - if (!fin) - { + if (!fin) { fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_lora); return 1; } @@ -3169,39 +2687,38 @@ int llama_apply_lora_from_file_internal(const struct llama_model &model, const c // verify magic and version { uint32_t magic; - fin.read((char *)&magic, sizeof(magic)); - if (magic != LLAMA_FILE_MAGIC_GGLA) - { + fin.read((char *) &magic, sizeof(magic)); + if (magic != LLAMA_FILE_MAGIC_GGLA) { fprintf(stderr, "%s: bad file magic\n", __func__); return 1; } uint32_t format_version; - fin.read((char *)&format_version, sizeof(format_version)); + fin.read((char *) &format_version, sizeof(format_version)); - if (format_version != 1) - { - fprintf(stderr, "%s: unsupported file version\n", __func__); + if (format_version != 1) { + fprintf(stderr, "%s: unsupported file version\n", __func__ ); return 1; } } int32_t lora_r; int32_t lora_alpha; - fin.read((char *)&lora_r, sizeof(lora_r)); - fin.read((char *)&lora_alpha, sizeof(lora_alpha)); + fin.read((char *) &lora_r, sizeof(lora_r)); + fin.read((char *) &lora_alpha, sizeof(lora_alpha)); float scaling = (float)lora_alpha / (float)lora_r; fprintf(stderr, "%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling); + // create a temporary ggml context to store the lora tensors // todo: calculate size from biggest possible tensor std::vector lora_buf(1024ull * 1024ull * 1024ull); struct ggml_init_params params; - params.mem_size = lora_buf.size(); + params.mem_size = lora_buf.size(); params.mem_buffer = lora_buf.data(); - params.no_alloc = false; + params.no_alloc = false; - ggml_context *lora_ctx = ggml_init(params); + ggml_context * lora_ctx = ggml_init(params); std::unordered_map lora_tensors; // create a name -> tensor map of the model to accelerate lookups @@ -3210,12 +2727,12 @@ int llama_apply_lora_from_file_internal(const struct llama_model &model, const c model_tensors.insert(kv); } + // load base model std::unique_ptr model_loader; - ggml_context *base_ctx = NULL; + ggml_context * base_ctx = NULL; llama_buffer base_buf; - if (path_base_model) - { + if (path_base_model) { fprintf(stderr, "%s: loading base model from '%s'\n", __func__, path_base_model); model_loader.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true)); @@ -3225,9 +2742,9 @@ int llama_apply_lora_from_file_internal(const struct llama_model &model, const c base_buf.resize(ctx_size); ggml_init_params base_params; - base_params.mem_size = base_buf.size; + base_params.mem_size = base_buf.size; base_params.mem_buffer = base_buf.addr; - base_params.no_alloc = model_loader->use_mmap; + base_params.no_alloc = model_loader->use_mmap; base_ctx = ggml_init(base_params); @@ -3242,23 +2759,20 @@ int llama_apply_lora_from_file_internal(const struct llama_model &model, const c // read tensors and apply bool warned = false; int n_tensors = 0; - while (true) - { + while (true) { int32_t n_dims; int32_t length; int32_t ftype; fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); fin.read(reinterpret_cast(&length), sizeof(length)); - fin.read(reinterpret_cast(&ftype), sizeof(ftype)); - if (fin.eof()) - { + fin.read(reinterpret_cast(&ftype), sizeof(ftype)); + if (fin.eof()) { break; } - int32_t ne[2] = {1, 1}; - for (int i = 0; i < n_dims; ++i) - { + int32_t ne[2] = { 1, 1 }; + for (int i = 0; i < n_dims; ++i) { fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); } @@ -3272,8 +2786,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model &model, const c // check for lora suffix and get the type of tensor const std::string lora_suffix = ".lora"; size_t pos = name.rfind(lora_suffix); - if (pos == std::string::npos) - { + if (pos == std::string::npos) { fprintf(stderr, "%s: error: '%s' is not a lora tensor\n", __func__, name.c_str()); return 1; } @@ -3283,34 +2796,28 @@ int llama_apply_lora_from_file_internal(const struct llama_model &model, const c base_name.erase(pos); // fprintf(stderr, "%s: %s => %s (lora type %s) ", __func__, name.c_str(),base_name.c_str(), lora_type.c_str()); - if (model_tensors.find(base_name) == model_tensors.end()) - { + if (model_tensors.find(base_name) == model_tensors.end()) { fprintf(stderr, "%s: unknown tensor '%s' in lora adapter\n", __func__, name.data()); return 1; } // create ggml tensor ggml_type wtype; - switch (ftype) - { - case 0: - wtype = GGML_TYPE_F32; - break; - case 1: - wtype = GGML_TYPE_F16; - break; - default: - { - fprintf(stderr, "%s: invalid tensor data type '%d'\n", - __func__, ftype); - return false; + switch (ftype) { + case 0: wtype = GGML_TYPE_F32; break; + case 1: wtype = GGML_TYPE_F16; break; + default: + { + fprintf(stderr, "%s: invalid tensor data type '%d'\n", + __func__, ftype); + return false; + } } ggml_tensor * lora_tensor; if (n_dims == 2) { lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]); } - else - { + else { fprintf(stderr, "%s: unsupported tensor dimension %d\n", __func__, n_dims); return 1; } @@ -3321,14 +2828,13 @@ int llama_apply_lora_from_file_internal(const struct llama_model &model, const c size_t tensor_data_size = ggml_nbytes(lora_tensor); offset = (offset + 31) & -32; fin.seekg(offset); - fin.read((char *)lora_tensor->data, tensor_data_size); + fin.read((char*)lora_tensor->data, tensor_data_size); lora_tensors[name] = lora_tensor; // check if we have both A and B tensors and apply if (lora_tensors.find(base_name + ".loraA") != lora_tensors.end() && - lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) - { + lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) { ggml_tensor * dest_t = model_tensors[base_name]; @@ -3349,30 +2855,25 @@ int llama_apply_lora_from_file_internal(const struct llama_model &model, const c ggml_tensor * base_t; if (model_loader) { // load from base model - if (model_loader->tensors_map.name_to_idx.find(base_name) == model_loader->tensors_map.name_to_idx.end()) - { + if (model_loader->tensors_map.name_to_idx.find(base_name) == model_loader->tensors_map.name_to_idx.end()) { fprintf(stderr, "%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str()); return 1; } size_t idx = model_loader->tensors_map.name_to_idx[base_name]; - llama_load_tensor < = model_loader->tensors_map.tensors[idx]; - base_t = model_loader->get_tensor(base_name, {(uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1]}, GGML_BACKEND_CPU); - lt.data = (uint8_t *)lt.ggml_tensor->data; + llama_load_tensor & lt = model_loader->tensors_map.tensors[idx]; + base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU); + lt.data = (uint8_t *) lt.ggml_tensor->data; model_loader->load_data_for(lt); lt.ggml_tensor->data = lt.data; } - else - { + else { base_t = dest_t; } - if (ggml_is_quantized(base_t->type)) - { - if (!warned) - { + if (ggml_is_quantized(base_t->type)) { + if (!warned) { fprintf(stderr, "%s: warning: using a lora adapter with a quantized model may result in poor quality, " - "use a f16 or f32 base model with --lora-base\n", - __func__); + "use a f16 or f32 base model with --lora-base\n", __func__); warned = true; } } @@ -3385,11 +2886,9 @@ int llama_apply_lora_from_file_internal(const struct llama_model &model, const c GGML_ASSERT(loraB->type == GGML_TYPE_F32); ggml_set_name(loraB, "loraB"); - if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) - { + if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) { fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");" - " are you sure that this adapter is for this model?\n", - __func__, base_t->ne[0], loraA->ne[1]); + " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]); return 1; } @@ -3407,15 +2906,13 @@ int llama_apply_lora_from_file_internal(const struct llama_model &model, const c ggml_set_name(BA, "BA_scaled"); } - ggml_tensor *r; - if (base_t == dest_t) - { + ggml_tensor * r; + if (base_t == dest_t) { r = ggml_add_inplace(lora_ctx, dest_t, BA); offload_func_force_inplace(r); ggml_set_name(r, "r_add_inplace"); } - else - { + else { r = ggml_add(lora_ctx, base_t, BA); offload_func(r); ggml_set_name(r, "r_add"); @@ -3435,8 +2932,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model &model, const c lora_tensors.clear(); n_tensors++; - if (n_tensors % 4 == 0) - { + if (n_tensors % 4 == 0) { fprintf(stderr, "."); } } @@ -3444,8 +2940,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model &model, const c // TODO: this should be in a destructor, it will leak on failure ggml_free(lora_ctx); - if (base_ctx) - { + if (base_ctx) { ggml_free(base_ctx); } @@ -3455,38 +2950,29 @@ int llama_apply_lora_from_file_internal(const struct llama_model &model, const c return 0; } -int llama_apply_lora_from_file(struct llama_context *ctx, const char *path_lora, const char *path_base_model, int n_threads) -{ - try - { +int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) { + try { return llama_apply_lora_from_file_internal(ctx->model, path_lora, path_base_model, n_threads); - } - catch (const std::exception &err) - { + } catch (const std::exception & err) { fprintf(stderr, "%s: failed to apply lora adapter: %s\n", __func__, err.what()); return 1; } } -int llama_model_apply_lora_from_file(const struct llama_model *model, const char *path_lora, const char *path_base_model, int n_threads) -{ - try - { +int llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, const char * path_base_model, int n_threads) { + try { return llama_apply_lora_from_file_internal(*model, path_lora, path_base_model, n_threads); - } - catch (const std::exception &err) - { + } catch (const std::exception & err) { fprintf(stderr, "%s: failed to apply lora adapter: %s\n", __func__, err.what()); return 1; } } -int llama_get_kv_cache_token_count(const struct llama_context *ctx) -{ +int llama_get_kv_cache_token_count(const struct llama_context * ctx) { return ctx->kv_self.n; } -#define LLAMA_MAX_RNG_STATE (64 * 1024) +#define LLAMA_MAX_RNG_STATE (64*1024) void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { if (seed == LLAMA_DEFAULT_SEED) { @@ -3496,30 +2982,39 @@ void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { } // Returns the *maximum* size of the state -size_t llama_get_state_size(const struct llama_context *ctx) -{ +size_t llama_get_state_size(const struct llama_context * ctx) { // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. // for reference, std::mt19937(1337) serializes to 6701 bytes. - const size_t s_rng_size = sizeof(size_t); - const size_t s_rng = LLAMA_MAX_RNG_STATE; + const size_t s_rng_size = sizeof(size_t); + const size_t s_rng = LLAMA_MAX_RNG_STATE; const size_t s_logits_capacity = sizeof(size_t); - const size_t s_logits_size = sizeof(size_t); - const size_t s_logits = ctx->logits.capacity() * sizeof(float); - const size_t s_embedding_size = sizeof(size_t); - const size_t s_embedding = ctx->embedding.size() * sizeof(float); - const size_t s_kv_size = sizeof(size_t); - const size_t s_kv_ntok = sizeof(int); - const size_t s_kv = ctx->kv_self.buf.size; - - const size_t s_total = (+s_rng_size + s_rng + s_logits_capacity + s_logits_size + s_logits + s_embedding_size + s_embedding + s_kv_size + s_kv_ntok + s_kv); + const size_t s_logits_size = sizeof(size_t); + const size_t s_logits = ctx->logits.capacity() * sizeof(float); + const size_t s_embedding_size = sizeof(size_t); + const size_t s_embedding = ctx->embedding.size() * sizeof(float); + const size_t s_kv_size = sizeof(size_t); + const size_t s_kv_ntok = sizeof(int); + const size_t s_kv = ctx->kv_self.buf.size; + + const size_t s_total = ( + + s_rng_size + + s_rng + + s_logits_capacity + + s_logits_size + + s_logits + + s_embedding_size + + s_embedding + + s_kv_size + + s_kv_ntok + + s_kv + ); return s_total; } // Copies the state to the specified destination address -size_t llama_copy_state_data(struct llama_context *ctx, uint8_t *dst) -{ - uint8_t *out = dst; +size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { + uint8_t * out = dst; // copy rng { @@ -3532,24 +3027,19 @@ size_t llama_copy_state_data(struct llama_context *ctx, uint8_t *dst) memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE); memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); - memcpy(out, &rng_size, sizeof(rng_size)); - out += sizeof(rng_size); - memcpy(out, &rng_buf[0], LLAMA_MAX_RNG_STATE); - out += LLAMA_MAX_RNG_STATE; + memcpy(out, &rng_size, sizeof(rng_size)); out += sizeof(rng_size); + memcpy(out, &rng_buf[0], LLAMA_MAX_RNG_STATE); out += LLAMA_MAX_RNG_STATE; } // copy logits { - const size_t logits_cap = ctx->logits.capacity(); + const size_t logits_cap = ctx->logits.capacity(); const size_t logits_size = ctx->logits.size(); - memcpy(out, &logits_cap, sizeof(logits_cap)); - out += sizeof(logits_cap); - memcpy(out, &logits_size, sizeof(logits_size)); - out += sizeof(logits_size); + memcpy(out, &logits_cap, sizeof(logits_cap)); out += sizeof(logits_cap); + memcpy(out, &logits_size, sizeof(logits_size)); out += sizeof(logits_size); - if (logits_size) - { + if (logits_size) { memcpy(out, ctx->logits.data(), logits_size * sizeof(float)); } @@ -3560,11 +3050,9 @@ size_t llama_copy_state_data(struct llama_context *ctx, uint8_t *dst) { const size_t embedding_size = ctx->embedding.size(); - memcpy(out, &embedding_size, sizeof(embedding_size)); - out += sizeof(embedding_size); + memcpy(out, &embedding_size, sizeof(embedding_size)); out += sizeof(embedding_size); - if (embedding_size) - { + if (embedding_size) { memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); out += embedding_size * sizeof(float); } @@ -3572,43 +3060,40 @@ size_t llama_copy_state_data(struct llama_context *ctx, uint8_t *dst) // copy kv cache { - const auto &kv_self = ctx->kv_self; - const auto &hparams = ctx->model.hparams; - const int n_layer = hparams.n_layer; - const int n_embd = hparams.n_embd; - const int n_ctx = hparams.n_ctx; + const auto & kv_self = ctx->kv_self; + const auto & hparams = ctx->model.hparams; + const int n_layer = hparams.n_layer; + const int n_embd = hparams.n_embd; + const int n_ctx = hparams.n_ctx; const size_t kv_size = kv_self.buf.size; - const int kv_ntok = llama_get_kv_cache_token_count(ctx); + const int kv_ntok = llama_get_kv_cache_token_count(ctx); - memcpy(out, &kv_size, sizeof(kv_size)); - out += sizeof(kv_size); - memcpy(out, &kv_ntok, sizeof(kv_ntok)); - out += sizeof(kv_ntok); + memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size); + memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok); - if (kv_size) - { + if (kv_size) { const size_t elt_size = ggml_element_size(kv_self.k); - ggml_context *cpy_ctx = ggml_init({4096, NULL, /* no_alloc */ true}); + ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); ggml_cgraph gf{}; gf.n_threads = 1; - ggml_tensor *kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); + ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); kout3d->data = out; out += ggml_nbytes(kout3d); - ggml_tensor *vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer); + ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer); vout3d->data = out; out += ggml_nbytes(vout3d); - ggml_tensor *k3d = ggml_view_3d(cpy_ctx, kv_self.k, - n_embd, kv_ntok, n_layer, - elt_size * n_embd, elt_size * n_embd * n_ctx, 0); + ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, + n_embd, kv_ntok, n_layer, + elt_size*n_embd, elt_size*n_embd*n_ctx, 0); - ggml_tensor *v3d = ggml_view_3d(cpy_ctx, kv_self.v, - kv_ntok, n_embd, n_layer, - elt_size * n_ctx, elt_size * n_ctx * n_embd, 0); + ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v, + kv_ntok, n_embd, n_layer, + elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d)); @@ -3618,7 +3103,7 @@ size_t llama_copy_state_data(struct llama_context *ctx, uint8_t *dst) } } - const size_t written = out - dst; + const size_t written = out - dst; const size_t max_size = llama_get_state_size(ctx); LLAMA_ASSERT(written <= max_size); @@ -3627,19 +3112,16 @@ size_t llama_copy_state_data(struct llama_context *ctx, uint8_t *dst) } // Sets the state reading from the specified source address -size_t llama_set_state_data(struct llama_context *ctx, uint8_t *src) -{ - uint8_t *inp = src; +size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { + uint8_t * inp = src; // set rng { size_t rng_size; - char rng_buf[LLAMA_MAX_RNG_STATE]; + char rng_buf[LLAMA_MAX_RNG_STATE]; - memcpy(&rng_size, inp, sizeof(rng_size)); - inp += sizeof(rng_size); - memcpy(&rng_buf[0], inp, LLAMA_MAX_RNG_STATE); - inp += LLAMA_MAX_RNG_STATE; + memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size); + memcpy(&rng_buf[0], inp, LLAMA_MAX_RNG_STATE); inp += LLAMA_MAX_RNG_STATE; std::stringstream rng_ss; rng_ss.str(std::string(&rng_buf[0], rng_size)); @@ -3653,15 +3135,12 @@ size_t llama_set_state_data(struct llama_context *ctx, uint8_t *src) size_t logits_cap; size_t logits_size; - memcpy(&logits_cap, inp, sizeof(logits_cap)); - inp += sizeof(logits_cap); - memcpy(&logits_size, inp, sizeof(logits_size)); - inp += sizeof(logits_size); + memcpy(&logits_cap, inp, sizeof(logits_cap)); inp += sizeof(logits_cap); + memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size); LLAMA_ASSERT(ctx->logits.capacity() == logits_cap); - if (logits_size) - { + if (logits_size) { ctx->logits.resize(logits_size); memcpy(ctx->logits.data(), inp, logits_size * sizeof(float)); } @@ -3673,13 +3152,11 @@ size_t llama_set_state_data(struct llama_context *ctx, uint8_t *src) { size_t embedding_size; - memcpy(&embedding_size, inp, sizeof(embedding_size)); - inp += sizeof(embedding_size); + memcpy(&embedding_size, inp, sizeof(embedding_size)); inp += sizeof(embedding_size); LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size); - if (embedding_size) - { + if (embedding_size) { memcpy(ctx->embedding.data(), inp, embedding_size * sizeof(float)); inp += embedding_size * sizeof(float); } @@ -3687,45 +3164,42 @@ size_t llama_set_state_data(struct llama_context *ctx, uint8_t *src) // set kv cache { - const auto &kv_self = ctx->kv_self; - const auto &hparams = ctx->model.hparams; - const int n_layer = hparams.n_layer; - const int n_embd = hparams.n_embd; - const int n_ctx = hparams.n_ctx; + const auto & kv_self = ctx->kv_self; + const auto & hparams = ctx->model.hparams; + const int n_layer = hparams.n_layer; + const int n_embd = hparams.n_embd; + const int n_ctx = hparams.n_ctx; size_t kv_size; int kv_ntok; - memcpy(&kv_size, inp, sizeof(kv_size)); - inp += sizeof(kv_size); - memcpy(&kv_ntok, inp, sizeof(kv_ntok)); - inp += sizeof(kv_ntok); + memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size); + memcpy(&kv_ntok, inp, sizeof(kv_ntok)); inp += sizeof(kv_ntok); - if (kv_size) - { + if (kv_size) { LLAMA_ASSERT(kv_self.buf.size == kv_size); const size_t elt_size = ggml_element_size(kv_self.k); - ggml_context *cpy_ctx = ggml_init({4096, NULL, /* no_alloc */ true}); + ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); ggml_cgraph gf{}; gf.n_threads = 1; - ggml_tensor *kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); - kin3d->data = (void *)inp; + ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); + kin3d->data = (void *) inp; inp += ggml_nbytes(kin3d); - ggml_tensor *vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer); - vin3d->data = (void *)inp; + ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer); + vin3d->data = (void *) inp; inp += ggml_nbytes(vin3d); - ggml_tensor *k3d = ggml_view_3d(cpy_ctx, kv_self.k, - n_embd, kv_ntok, n_layer, - elt_size * n_embd, elt_size * n_embd * n_ctx, 0); + ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, + n_embd, kv_ntok, n_layer, + elt_size*n_embd, elt_size*n_embd*n_ctx, 0); - ggml_tensor *v3d = ggml_view_3d(cpy_ctx, kv_self.v, - kv_ntok, n_embd, n_layer, - elt_size * n_ctx, elt_size * n_ctx * n_embd, 0); + ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v, + kv_ntok, n_embd, n_layer, + elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d)); @@ -3737,7 +3211,7 @@ size_t llama_set_state_data(struct llama_context *ctx, uint8_t *src) ctx->kv_self.n = kv_ntok; } - const size_t nread = inp - src; + const size_t nread = inp - src; const size_t max_size = llama_get_state_size(ctx); LLAMA_ASSERT(nread <= max_size); @@ -3745,17 +3219,15 @@ size_t llama_set_state_data(struct llama_context *ctx, uint8_t *src) return nread; } -bool llama_load_session_file(struct llama_context *ctx, const char *path_session, llama_token *tokens_out, size_t n_token_capacity, size_t *n_token_count_out) -{ +bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { llama_file file(path_session, "rb"); // sanity checks { - const uint32_t magic = file.read_u32(); + const uint32_t magic = file.read_u32(); const uint32_t version = file.read_u32(); - if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) - { + if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) { fprintf(stderr, "%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version); return false; } @@ -3763,8 +3235,7 @@ bool llama_load_session_file(struct llama_context *ctx, const char *path_session llama_hparams session_hparams; file.read_raw(&session_hparams, sizeof(llama_hparams)); - if (session_hparams != ctx->model.hparams) - { + if (session_hparams != ctx->model.hparams) { fprintf(stderr, "%s : model hparams didn't match from session file!\n", __func__); return false; } @@ -3774,8 +3245,7 @@ bool llama_load_session_file(struct llama_context *ctx, const char *path_session { const uint32_t n_token_count = file.read_u32(); - if (n_token_count > n_token_capacity) - { + if (n_token_count > n_token_capacity) { fprintf(stderr, "%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); return false; } @@ -3789,8 +3259,7 @@ bool llama_load_session_file(struct llama_context *ctx, const char *path_session const size_t n_state_size_cur = file.size - file.tell(); const size_t n_state_size_max = llama_get_state_size(ctx); - if (n_state_size_cur > n_state_size_max) - { + if (n_state_size_cur > n_state_size_max) { fprintf(stderr, "%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur); return false; } @@ -3804,8 +3273,7 @@ bool llama_load_session_file(struct llama_context *ctx, const char *path_session return true; } -bool llama_save_session_file(struct llama_context *ctx, const char *path_session, const llama_token *tokens, size_t n_token_count) -{ +bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { llama_file file(path_session, "wb"); file.write_u32(LLAMA_SESSION_MAGIC); @@ -3814,7 +3282,7 @@ bool llama_save_session_file(struct llama_context *ctx, const char *path_session file.write_raw(&ctx->model.hparams, sizeof(llama_hparams)); // save the prompt - file.write_u32((uint32_t)n_token_count); + file.write_u32((uint32_t) n_token_count); file.write_raw(tokens, sizeof(llama_token) * n_token_count); // save the context state @@ -3865,8 +3333,7 @@ int llama_eval_embd( // get a more accurate load time, upon first eval // TODO: fix this - if (!ctx->has_evaluated_once) - { + if (!ctx->has_evaluated_once) { ctx->t_load_us = ggml_time_us() - ctx->t_start_us; ctx->has_evaluated_once = true; } @@ -3874,10 +3341,9 @@ int llama_eval_embd( return 0; } -int llama_eval_export(struct llama_context *ctx, const char *fname) -{ +int llama_eval_export(struct llama_context * ctx, const char * fname) { const int n_batch = 1; - const int n_ctx = 512 - n_batch; + const int n_ctx = 512 - n_batch; const std::vector tmp(n_batch, llama_token_bos()); @@ -3890,99 +3356,84 @@ int llama_eval_export(struct llama_context *ctx, const char *fname) } int llama_tokenize( - struct llama_context *ctx, - const char *text, - llama_token *tokens, - int n_max_tokens, - bool add_bos) -{ + struct llama_context * ctx, + const char * text, + llama_token * tokens, + int n_max_tokens, + bool add_bos) { auto res = llama_tokenize(ctx->vocab, text, add_bos); - if (n_max_tokens < (int)res.size()) - { + if (n_max_tokens < (int) res.size()) { fprintf(stderr, "%s: too many tokens\n", __func__); - return -((int)res.size()); + return -((int) res.size()); } - for (size_t i = 0; i < res.size(); i++) - { + for (size_t i = 0; i < res.size(); i++) { tokens[i] = res[i]; } return res.size(); } -int llama_n_vocab(const struct llama_context *ctx) -{ +int llama_n_vocab(const struct llama_context * ctx) { return ctx->vocab.id_to_token.size(); } -int llama_n_ctx(const struct llama_context *ctx) -{ +int llama_n_ctx(const struct llama_context * ctx) { return ctx->model.hparams.n_ctx; } -int llama_n_embd(const struct llama_context *ctx) -{ +int llama_n_embd(const struct llama_context * ctx) { return ctx->model.hparams.n_embd; } int llama_get_vocab( - const struct llama_context *ctx, - const char **strings, - float *scores, - int capacity) -{ - int n = std::min(capacity, (int)ctx->vocab.id_to_token.size()); - for (int i = 0; i < n; ++i) - { + const struct llama_context * ctx, + const char * * strings, + float * scores, + int capacity) { + int n = std::min(capacity, (int) ctx->vocab.id_to_token.size()); + for (int i = 0; ivocab.id_to_token[i].tok.c_str(); - scores[i] = ctx->vocab.id_to_token[i].score; + scores[i] = ctx->vocab.id_to_token[i].score; } return n; } -float *llama_get_logits(struct llama_context *ctx) -{ +float * llama_get_logits(struct llama_context * ctx) { return ctx->logits.data(); } -float *llama_get_embeddings(struct llama_context *ctx) -{ +float * llama_get_embeddings(struct llama_context * ctx) { return ctx->embedding.data(); } -const char *llama_token_to_str(const struct llama_context *ctx, llama_token token) -{ - if (token >= llama_n_vocab(ctx)) - { +const char * llama_token_to_str(const struct llama_context * ctx, llama_token token) { + if (token >= llama_n_vocab(ctx)) { return nullptr; } return ctx->vocab.id_to_token[token].tok.c_str(); } -llama_token llama_token_bos() -{ +llama_token llama_token_bos() { return 1; } -llama_token llama_token_eos() -{ +llama_token llama_token_eos() { return 2; } -llama_token llama_token_nl() -{ +llama_token llama_token_nl() { return 13; } -void llama_print_timings(struct llama_context *ctx) -{ + +void llama_print_timings(struct llama_context * ctx) { const int64_t t_end_us = ggml_time_us(); const int32_t n_sample = std::max(1, ctx->n_sample); - const int32_t n_eval = std::max(1, ctx->n_eval); + const int32_t n_eval = std::max(1, ctx->n_eval); const int32_t n_p_eval = std::max(1, ctx->n_p_eval); fprintf(stderr, "\n"); @@ -3992,43 +3443,40 @@ void llama_print_timings(struct llama_context *ctx) fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", __func__, 1e-3 * ctx->t_p_eval_us, n_p_eval, 1e-3 * ctx->t_p_eval_us / n_p_eval, 1e6 / ctx->t_p_eval_us * n_p_eval); fprintf(stderr, "%s: eval time = %8.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, 1e-3 * ctx->t_eval_us, n_eval, 1e-3 * ctx->t_eval_us / n_eval, 1e6 / ctx->t_eval_us * n_eval); - fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us) / 1000.0); + __func__, 1e-3 * ctx->t_eval_us, n_eval, 1e-3 * ctx->t_eval_us / n_eval, 1e6 / ctx->t_eval_us * n_eval); + fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0); } -void llama_reset_timings(struct llama_context *ctx) -{ +void llama_reset_timings(struct llama_context * ctx) { ctx->t_start_us = ggml_time_us(); ctx->t_sample_us = ctx->n_sample = 0; - ctx->t_eval_us = ctx->n_eval = 0; + ctx->t_eval_us = ctx->n_eval = 0; ctx->t_p_eval_us = ctx->n_p_eval = 0; } -const char *llama_print_system_info(void) -{ +const char * llama_print_system_info(void) { static std::string s; - s = ""; - s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; - s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; - s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; + s = ""; + s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; + s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; + s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; s += "AVX512_VBMI = " + std::to_string(ggml_cpu_has_avx512_vbmi()) + " | "; s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | "; - s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; - s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; - s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; - s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; - s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; - s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; - s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; - s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; - s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; + s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; + s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; + s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; + s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; + s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; + s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; + s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; + s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; + s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; return s.c_str(); } // For internal test use -const std::vector> &llama_internal_get_tensor_map(struct llama_context *ctx) -{ +const std::vector>& llama_internal_get_tensor_map(struct llama_context * ctx) { return ctx->model.tensors_by_name; -} +} \ No newline at end of file From c934f7076857930e8ebb33876040b10d85752d41 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Mon, 10 Jul 2023 12:31:37 +0800 Subject: [PATCH 010/180] updated our llama.cpp to be the same as master branch --- llama.cpp | 440 ++++++++++++++++++++++++++++++------------------------ 1 file changed, 243 insertions(+), 197 deletions(-) diff --git a/llama.cpp b/llama.cpp index 45f7976587714..ffa447bae7236 100644 --- a/llama.cpp +++ b/llama.cpp @@ -66,6 +66,7 @@ enum e_model { MODEL_65B, }; +static const size_t kB = 1024; static const size_t MB = 1024*1024; // computed for n_ctx == 2048 @@ -78,14 +79,33 @@ void llama_nop(struct ggml_tensor * tensor) { // don't offload by default (void) tensor; } +// +// ggml helpers +// + +static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads) { + struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); + + if (plan.work_size > 0) { + buf.resize(plan.work_size); + plan.work_data = buf.data(); + } + + ggml_graph_compute(graph, &plan); +} + +// +// memory sizes +// + static const std::map & MEM_REQ_SCRATCH0() { static std::map k_sizes = { - { MODEL_3B, 256ull * MB }, - { MODEL_7B, 512ull * MB }, - { MODEL_13B, 512ull * MB }, - { MODEL_30B, 512ull * MB }, - { MODEL_65B, 1024ull * MB }, + { MODEL_3B, 256ull * MB }, + { MODEL_7B, 512ull * MB }, + { MODEL_13B, 512ull * MB }, + { MODEL_30B, 512ull * MB }, + { MODEL_65B, 1024ull * MB }, }; return k_sizes; } @@ -93,11 +113,11 @@ static const std::map & MEM_REQ_SCRATCH0() static const std::map & MEM_REQ_SCRATCH1() { static std::map k_sizes = { - { MODEL_3B, 256ull * MB }, - { MODEL_7B, 512ull * MB }, - { MODEL_13B, 512ull * MB }, - { MODEL_30B, 512ull * MB }, - { MODEL_65B, 1024ull * MB }, + { MODEL_3B, 256ull * MB }, + { MODEL_7B, 512ull * MB }, + { MODEL_13B, 512ull * MB }, + { MODEL_30B, 512ull * MB }, + { MODEL_65B, 1024ull * MB }, }; return k_sizes; } @@ -106,11 +126,11 @@ static const std::map & MEM_REQ_SCRATCH1() static const std::map & MEM_REQ_KV_SELF() { static std::map k_sizes = { - { MODEL_3B, 682ull * MB }, - { MODEL_7B, 1026ull * MB }, - { MODEL_13B, 1608ull * MB }, - { MODEL_30B, 3124ull * MB }, - { MODEL_65B, 5120ull * MB }, + { MODEL_3B, 682ull * MB }, + { MODEL_7B, 1026ull * MB }, + { MODEL_13B, 1608ull * MB }, + { MODEL_30B, 3124ull * MB }, + { MODEL_65B, 5120ull * MB }, }; return k_sizes; } @@ -120,11 +140,11 @@ static const std::map & MEM_REQ_KV_SELF() static const std::map & MEM_REQ_EVAL() { static std::map k_sizes = { - { MODEL_3B, 512ull * MB }, - { MODEL_7B, 768ull * MB }, - { MODEL_13B, 1024ull * MB }, - { MODEL_30B, 1280ull * MB }, - { MODEL_65B, 1536ull * MB }, + { MODEL_3B, 512ull * MB }, + { MODEL_7B, 768ull * MB }, + { MODEL_13B, 1024ull * MB }, + { MODEL_30B, 1280ull * MB }, + { MODEL_65B, 1536ull * MB }, }; return k_sizes; } @@ -134,11 +154,11 @@ static const std::map & MEM_REQ_EVAL() static const std::map & VRAM_REQ_SCRATCH_BASE() { static std::map k_sizes = { - { MODEL_3B, 512ull * kB }, - { MODEL_7B, 512ull * kB }, - { MODEL_13B, 640ull * kB }, - { MODEL_30B, 768ull * kB }, - { MODEL_65B, 1536ull * kB }, + { MODEL_3B, 512ull * kB }, + { MODEL_7B, 512ull * kB }, + { MODEL_13B, 640ull * kB }, + { MODEL_30B, 768ull * kB }, + { MODEL_65B, 1536ull * kB }, }; return k_sizes; } @@ -148,11 +168,11 @@ static const std::map & VRAM_REQ_SCRATCH_BASE() static const std::map & VRAM_REQ_SCRATCH_PER_CONTEXT() { static std::map k_sizes = { - { MODEL_3B, 128ull }, - { MODEL_7B, 128ull }, - { MODEL_13B, 160ull }, - { MODEL_30B, 208ull }, - { MODEL_65B, 416ull }, + { MODEL_3B, 128ull }, + { MODEL_7B, 128ull }, + { MODEL_13B, 160ull }, + { MODEL_30B, 208ull }, + { MODEL_65B, 416ull }, }; return k_sizes; } @@ -193,8 +213,8 @@ struct llama_layer { }; struct llama_kv_cache { - struct ggml_tensor * k; - struct ggml_tensor * v; + struct ggml_tensor * k = NULL; + struct ggml_tensor * v = NULL; struct ggml_context * ctx = NULL; @@ -281,6 +301,13 @@ struct llama_model { struct llama_context { llama_context(const llama_model & model, const llama_vocab & vocab) : model(model), vocab(vocab), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {} +#ifdef GGML_USE_METAL + ~llama_context() { + if (ctx_metal) { + ggml_metal_free(ctx_metal); + } + } +#endif std::mt19937 rng; bool has_evaluated_once = false; @@ -365,7 +392,7 @@ static T checked_mul(T a, T b) { T ret = a * b; if (a != 0 && ret / a != b) { throw std::runtime_error(format("overflow multiplying %llu * %llu", - (unsigned long long) a, (unsigned long long) b)); + (unsigned long long) a, (unsigned long long) b)); } return ret; } @@ -425,7 +452,7 @@ struct llama_file_loader { llama_vocab vocab; llama_file_loader(const char * fname, llama_load_tensors_map & tensors_map) - : file(fname, "rb") { + : file(fname, "rb") { fprintf(stderr, "llama.cpp: loading model from %s\n", fname); read_magic(); read_hparams(); @@ -457,7 +484,7 @@ struct llama_file_loader { } throw std::runtime_error(format("unknown (magic, version) combination: %08x, %08x; is this really a GGML file?", - magic, version)); + magic, version)); } void read_hparams() { hparams.n_vocab = file.read_u32(); @@ -476,9 +503,7 @@ struct llama_file_loader { std::string word = file.read_string(len); float score = 0.0f; - if (file_version >= LLAMA_FILE_VERSION_GGMF_V1) { - file.read_raw(&score, sizeof(score)); - } + file.read_raw(&score, sizeof(score)); vocab.token_to_id[word] = i; @@ -536,7 +561,7 @@ struct llama_file_saver { llama_file file; llama_file_loader * any_file_loader; llama_file_saver(const char * fname, llama_file_loader * any_file_loader, enum llama_ftype new_ftype) - : file(fname, "wb"), any_file_loader(any_file_loader) { + : file(fname, "wb"), any_file_loader(any_file_loader) { fprintf(stderr, "llama.cpp: saving model to %s\n", fname); write_magic(); write_hparams(new_ftype); @@ -628,7 +653,7 @@ struct llama_model_loader { llama_load_tensor & lt = tensors_map.tensors.at(it->second); if (lt.ne != ne) { throw std::runtime_error(format("llama.cpp: tensor '%s' has wrong shape; expected %s, got %s", - name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str())); + name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str())); } return get_tensor_for(lt, backend); @@ -706,7 +731,7 @@ struct llama_model_loader { } break; #if defined(GGML_USE_CUBLAS) - case GGML_BACKEND_GPU: + case GGML_BACKEND_GPU: case GGML_BACKEND_GPU_SPLIT: ggml_cuda_transform_tensor(lt.data, lt.ggml_tensor); if (!use_mmap) { @@ -714,7 +739,7 @@ struct llama_model_loader { } break; #elif defined(GGML_USE_CLBLAST) - case GGML_BACKEND_GPU: + case GGML_BACKEND_GPU: ggml_cl_transform_tensor(lt.data, lt.ggml_tensor); if (!use_mmap) { free(lt.data); @@ -755,17 +780,16 @@ struct llama_model_loader { }; - // // kv cache // static bool kv_cache_init( const struct llama_hparams & hparams, - struct llama_kv_cache & cache, - ggml_type wtype, - int n_ctx, - int n_gpu_layers) { + struct llama_kv_cache & cache, + ggml_type wtype, + int n_ctx, + int n_gpu_layers) { const int n_embd = hparams.n_embd; const int n_layer = hparams.n_layer; @@ -807,21 +831,21 @@ static bool kv_cache_init( struct llama_context_params llama_context_default_params() { struct llama_context_params result = { - /*.seed =*/ LLAMA_DEFAULT_SEED, - /*.n_ctx =*/ 512, - /*.n_batch =*/ 512, - /*.gpu_layers =*/ 0, - /*.main_gpu =*/ 0, - /*.tensor_split =*/ {0}, - /*.progress_callback =*/ nullptr, - /*.progress_callback_user_data =*/ nullptr, - /*.low_vram =*/ false, - /*.f16_kv =*/ true, - /*.logits_all =*/ false, - /*.vocab_only =*/ false, - /*.use_mmap =*/ true, - /*.use_mlock =*/ false, - /*.embedding =*/ false, + /*.seed =*/ LLAMA_DEFAULT_SEED, + /*.n_ctx =*/ 512, + /*.n_batch =*/ 512, + /*.gpu_layers =*/ 0, + /*.main_gpu =*/ 0, + /*.tensor_split =*/ {0}, + /*.progress_callback =*/ nullptr, + /*.progress_callback_user_data =*/ nullptr, + /*.low_vram =*/ false, + /*.f16_kv =*/ true, + /*.logits_all =*/ false, + /*.vocab_only =*/ false, + /*.use_mmap =*/ true, + /*.use_mlock =*/ false, + /*.embedding =*/ false, }; return result; @@ -829,10 +853,10 @@ struct llama_context_params llama_context_default_params() { struct llama_model_quantize_params llama_model_quantize_default_params() { struct llama_model_quantize_params result = { - /*.nthread =*/ 0, - /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1, - /*.allow_requantize =*/ false, - /*.quantize_output_tensor =*/ true, + /*.nthread =*/ 0, + /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1, + /*.allow_requantize =*/ false, + /*.quantize_output_tensor =*/ true, }; return result; @@ -888,11 +912,11 @@ static const char *llama_ftype_name(enum llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q4_0: return "mostly Q4_0"; case LLAMA_FTYPE_MOSTLY_Q4_1: return "mostly Q4_1"; case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: - return "mostly Q4_1, some F16"; + return "mostly Q4_1, some F16"; case LLAMA_FTYPE_MOSTLY_Q5_0: return "mostly Q5_0"; case LLAMA_FTYPE_MOSTLY_Q5_1: return "mostly Q5_1"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0"; - // K-quants + // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K: return "mostly Q2_K"; case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "mostly Q3_K - Small"; case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "mostly Q3_K - Medium"; @@ -952,11 +976,11 @@ static void llama_model_load_internal( case 60: model.type = e_model::MODEL_30B; break; case 80: model.type = e_model::MODEL_65B; break; default: - { - if (hparams.n_layer < 32) { - model.type = e_model::MODEL_7B; - } - } break; + { + if (hparams.n_layer < 32) { + model.type = e_model::MODEL_7B; + } + } break; } hparams.n_ctx = n_ctx; @@ -1014,9 +1038,9 @@ static void llama_model_load_internal( } struct ggml_init_params params = { - /*.mem_size =*/ model.buf.size, - /*.mem_buffer =*/ model.buf.addr, - /*.no_alloc =*/ ml->use_mmap, + /*.mem_size =*/ model.buf.size, + /*.mem_buffer =*/ model.buf.addr, + /*.no_alloc =*/ ml->use_mmap, }; model.ctx = ggml_init(params); @@ -1107,9 +1131,9 @@ static void llama_model_load_internal( if (backend == GGML_BACKEND_GPU) { vram_weights += - ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + - ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + - ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); + ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + + ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); } } } @@ -1122,15 +1146,15 @@ static void llama_model_load_internal( // this is the total memory required to run the inference const size_t mem_required = - ctx_size + - mmapped_size - vram_weights + // weights in VRAM not in memory - MEM_REQ_SCRATCH0().at(model.type) + - MEM_REQ_SCRATCH1().at(model.type) + - MEM_REQ_EVAL().at (model.type); + ctx_size + + mmapped_size - vram_weights + // weights in VRAM not in memory + MEM_REQ_SCRATCH0().at(model.type) + + MEM_REQ_SCRATCH1().at(model.type) + + MEM_REQ_EVAL().at (model.type); // this is the memory required by one llama_state const size_t mem_required_state = - scale*MEM_REQ_KV_SELF().at(model.type); + scale*MEM_REQ_KV_SELF().at(model.type); fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__, mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0); @@ -1142,11 +1166,14 @@ static void llama_model_load_internal( fprintf(stderr, "%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__); ggml_cuda_set_scratch_size(0); // disable scratch } else { - vram_scratch = n_batch * MB; + const size_t vram_scratch_base = VRAM_REQ_SCRATCH_BASE().at(model.type); + const size_t vram_scratch_per_context = VRAM_REQ_SCRATCH_PER_CONTEXT().at(model.type); + vram_scratch = n_batch * (vram_scratch_base + n_ctx * vram_scratch_per_context); ggml_cuda_set_scratch_size(vram_scratch); if (n_gpu_layers > 0) { - fprintf(stderr, "%s: allocating batch_size x 1 MB = %zd MB VRAM for the scratch buffer\n", - __func__, vram_scratch / MB); + fprintf(stderr, "%s: allocating batch_size x (%zd kB + n_ctx x %zd B) = %zd MB VRAM for the scratch buffer\n", + __func__, vram_scratch_base / kB, vram_scratch_per_context, + (vram_scratch + MB - 1) / MB); // round up } } #endif // GGML_USE_CUBLAS @@ -1159,6 +1186,10 @@ static void llama_model_load_internal( fprintf(stderr, "%s: offloading non-repeating layers to GPU\n", __func__); } size_t vram_kv_cache = 0; + +#ifdef GGML_USE_CUBLAS + const int max_backend_supported_layers = hparams.n_layer + 3; + const int max_offloadable_layers = low_vram ? hparams.n_layer + 1 : hparams.n_layer + 3; if (n_gpu_layers > (int) hparams.n_layer + 1) { if (low_vram) { fprintf(stderr, "%s: cannot offload v cache to GPU due to low VRAM option\n", __func__); @@ -1186,7 +1217,7 @@ static void llama_model_load_internal( __func__, (vram_weights + vram_scratch + vram_kv_cache + MB - 1) / MB); // round up #else (void) n_gpu_layers; -#endif +#endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) } // populate `tensors_by_name` @@ -1250,22 +1281,16 @@ static bool llama_model_load( // - n_threads: number of threads to use // static bool llama_eval_internal( - llama_context & lctx, - const llama_token * tokens, - const float * embd, - const int n_tokens, - const int n_past, - int n_threads, - const char * cgraph_fname) { + llama_context & lctx, + const llama_token * tokens, + const float * embd, + const int n_tokens, + const int n_past, + int n_threads, + const char * cgraph_fname) { LLAMA_ASSERT((!tokens && embd) || (tokens && !embd)); - // enforce that the first token is BOS - if (tokens && n_past == 0 && tokens[0] != llama_token_bos()) { - fprintf(stderr, "%s: first token must be BOS\n", __func__); - return false; - } - const int64_t t_start_us = ggml_time_us(); const int N = n_tokens; @@ -1289,9 +1314,9 @@ static bool llama_eval_internal( auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.addr, - /*.no_alloc =*/ false, + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.addr, + /*.no_alloc =*/ false, }; struct ggml_context * ctx0 = ggml_init(params); @@ -1328,7 +1353,7 @@ static bool llama_eval_internal( offload_func_t offload_func_v = llama_nop; #ifdef GGML_USE_CUBLAS - if (n_gpu_layers > n_layer) { + if (n_gpu_layers > n_layer) { offload_func_nr = ggml_cuda_assign_buffers; } if (n_gpu_layers > n_layer + 1) { @@ -1400,8 +1425,8 @@ static bool llama_eval_internal( ggml_set_name(k, "k"); struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); offload_func_v(v); ggml_set_name(v, "v"); @@ -1411,18 +1436,18 @@ static bool llama_eval_internal( } struct ggml_tensor * Q = - ggml_permute(ctx0, - Qcur, - 0, 2, 1, 3); + ggml_permute(ctx0, + Qcur, + 0, 2, 1, 3); offload_func_kq(Q); ggml_set_name(Q, "Q"); struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd), - n_embd/n_head, n_head, n_past + N), - 0, 2, 1, 3); + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd), + n_embd/n_head, n_head, n_past + N), + 0, 2, 1, 3); offload_func_kq(K); ggml_set_name(K, "K"); @@ -1452,11 +1477,11 @@ static bool llama_eval_internal( // split cached V into n_head heads struct ggml_tensor * V = - ggml_view_3d(ctx0, kv_self.v, - n_past + N, n_embd/n_head, n_head, - n_ctx*ggml_element_size(kv_self.v), - n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head, - il*n_ctx*ggml_element_size(kv_self.v)*n_embd); + ggml_view_3d(ctx0, kv_self.v, + n_past + N, n_embd/n_head, n_head, + n_ctx*ggml_element_size(kv_self.v), + n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head, + il*n_ctx*ggml_element_size(kv_self.v)*n_embd); offload_func_v(V); ggml_set_name(V, "V"); @@ -1479,15 +1504,15 @@ static bool llama_eval_internal( // cur = KQV_merged.contiguous().view(n_embd, N) cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); offload_func_v(cur); ggml_set_name(cur, "KQV_merged_contiguous"); // projection (no bias) cur = ggml_mul_mat(ctx0, - model.layers[il].wo, - cur); + model.layers[il].wo, + cur); offload_func(cur); ggml_set_name(cur, "result_wo"); } @@ -1513,14 +1538,14 @@ static bool llama_eval_internal( } struct ggml_tensor * tmp = ggml_mul_mat(ctx0, - model.layers[il].w3, - cur); + model.layers[il].w3, + cur); offload_func(tmp); ggml_set_name(tmp, "result_w3"); cur = ggml_mul_mat(ctx0, - model.layers[il].w1, - cur); + model.layers[il].w1, + cur); offload_func(cur); ggml_set_name(cur, "result_w1"); @@ -1534,8 +1559,8 @@ static bool llama_eval_internal( ggml_set_name(cur, "silu_x_result_w3"); cur = ggml_mul_mat(ctx0, - model.layers[il].w2, - cur); + model.layers[il].w2, + cur); offload_func(cur); ggml_set_name(cur, "result_w2"); } @@ -1584,6 +1609,7 @@ static bool llama_eval_internal( #ifdef GGML_USE_METAL if (lctx.ctx_metal && N == 1) { + ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); ggml_metal_graph_compute(lctx.ctx_metal, &gf); ggml_metal_get_tensor (lctx.ctx_metal, cur); } else { @@ -2159,7 +2185,6 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_ } llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { - assert(ctx); int64_t t_start_sample_us; t_start_sample_us = ggml_time_us(); @@ -2182,9 +2207,6 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok llama_sample_softmax(ctx, candidates); // Sample the next word X from the remaining words - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } llama_token X = llama_sample_token(ctx, candidates); t_start_sample_us = ggml_time_us(); @@ -2252,10 +2274,10 @@ static void llama_convert_tensor_internal(const llama_load_tensor & tensor, llam } float * f32_output = (float *) output.addr; - quantize_fns_t qtype; + ggml_type_traits_t qtype; if (ggml_is_quantized(tensor.type)) { - qtype = ggml_internal_get_quantize_fn(tensor.type); - if (qtype.dequantize_row_q == NULL) { + qtype = ggml_internal_get_type_traits(tensor.type); + if (qtype.to_float == NULL) { throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor.type))); } } else if (tensor.type != GGML_TYPE_F16) { @@ -2266,7 +2288,7 @@ static void llama_convert_tensor_internal(const llama_load_tensor & tensor, llam if (tensor.type == GGML_TYPE_F16) { ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor.data, f32_output, nelements); } else if (ggml_is_quantized(tensor.type)) { - qtype.dequantize_row_q(tensor.data, f32_output, nelements); + qtype.to_float(tensor.data, f32_output, nelements); } else { LLAMA_ASSERT(false); // unreachable } @@ -2291,7 +2313,7 @@ static void llama_convert_tensor_internal(const llama_load_tensor & tensor, llam if (typ == GGML_TYPE_F16) { ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels); } else { - qtype.dequantize_row_q(inbuf, outbuf, nels); + qtype.to_float(inbuf, outbuf, nels); } }; workers.push_back(std::thread(compute, tensor.type, tensor.data + in_buff_offs, f32_output + out_buff_offs, thr_elems)); @@ -2319,7 +2341,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_ALL_F32: quantized_type = GGML_TYPE_F32; break; #ifdef GGML_USE_K_QUANTS - // K-quants + // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K: quantized_type = GGML_TYPE_Q2_K; break; case LLAMA_FTYPE_MOSTLY_Q3_K_S: case LLAMA_FTYPE_MOSTLY_Q3_K_M: @@ -2421,9 +2443,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && - use_more_bits(i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K; + use_more_bits(i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K; else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) && - (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K; + (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K; ++i_attention_wv; } else if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) { if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; @@ -2546,8 +2568,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // struct llama_model * llama_load_model_from_file( - const char * path_model, - struct llama_context_params params) { + const char * path_model, + struct llama_context_params params) { ggml_time_init(); llama_model * model = new llama_model; @@ -2555,8 +2577,8 @@ struct llama_model * llama_load_model_from_file( ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gpu_layers, - params.main_gpu, params.tensor_split, params.low_vram, memory_type, params.use_mmap, params.use_mlock, - params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { + params.main_gpu, params.tensor_split, params.low_vram, memory_type, params.use_mmap, params.use_mlock, + params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { delete model; fprintf(stderr, "%s: failed to load model\n", __func__); return nullptr; @@ -2570,8 +2592,8 @@ void llama_free_model(struct llama_model * model) { } struct llama_context * llama_new_context_with_model( - struct llama_model * model, - struct llama_context_params params) { + struct llama_model * model, + struct llama_context_params params) { if (!model) { return nullptr; @@ -2679,8 +2701,8 @@ struct llama_context * llama_new_context_with_model( } struct llama_context * llama_init_from_file( - const char * path_model, - struct llama_context_params params) { + const char * path_model, + struct llama_context_params params) { struct llama_model * model = llama_load_model_from_file(path_model, params); if (!model) { @@ -2797,6 +2819,9 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const // read tensors and apply bool warned = false; int n_tensors = 0; + + std::vector work_buffer; + while (true) { int32_t n_dims; int32_t length; @@ -2845,11 +2870,11 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const case 0: wtype = GGML_TYPE_F32; break; case 1: wtype = GGML_TYPE_F16; break; default: - { - fprintf(stderr, "%s: invalid tensor data type '%d'\n", - __func__, ftype); - return false; - } + { + fprintf(stderr, "%s: invalid tensor data type '%d'\n", + __func__, ftype); + return false; + } } ggml_tensor * lora_tensor; if (n_dims == 2) { @@ -2926,7 +2951,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) { fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");" - " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]); + " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]); return 1; } @@ -3035,16 +3060,16 @@ size_t llama_get_state_size(const struct llama_context * ctx) { const size_t s_kv = ctx->kv_self.buf.size; const size_t s_total = ( - + s_rng_size - + s_rng - + s_logits_capacity - + s_logits_size - + s_logits - + s_embedding_size - + s_embedding - + s_kv_size - + s_kv_ntok - + s_kv + + s_rng_size + + s_rng + + s_logits_capacity + + s_logits_size + + s_logits + + s_embedding_size + + s_embedding + + s_kv_size + + s_kv_ntok + + s_kv ); return s_total; @@ -3125,12 +3150,12 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { out += ggml_nbytes(vout3d); ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, - n_embd, kv_ntok, n_layer, - elt_size*n_embd, elt_size*n_embd*n_ctx, 0); + n_embd, kv_ntok, n_layer, + elt_size*n_embd, elt_size*n_embd*n_ctx, 0); ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v, - kv_ntok, n_embd, n_layer, - elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); + kv_ntok, n_embd, n_layer, + elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d)); @@ -3230,12 +3255,12 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { inp += ggml_nbytes(vin3d); ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, - n_embd, kv_ntok, n_layer, - elt_size*n_embd, elt_size*n_embd*n_ctx, 0); + n_embd, kv_ntok, n_layer, + elt_size*n_embd, elt_size*n_embd*n_ctx, 0); ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v, - kv_ntok, n_embd, n_layer, - elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); + kv_ntok, n_embd, n_layer, + elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d)); @@ -3255,7 +3280,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { return nread; } -bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { +static bool llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { llama_file file(path_session, "rb"); // sanity checks @@ -3309,6 +3334,15 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi return true; } +bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + try { + return llama_load_session_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); + } catch (const std::exception & err) { + fprintf(stderr, "error loading session file: %s\n", err.what()); + return false; + } +} + bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { llama_file file(path_session, "wb"); @@ -3336,10 +3370,10 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi int llama_eval( struct llama_context * ctx, - const llama_token * tokens, - int n_tokens, - int n_past, - int n_threads) { + const llama_token * tokens, + int n_tokens, + int n_past, + int n_threads) { if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; @@ -3357,11 +3391,11 @@ int llama_eval( int llama_eval_embd( - struct llama_context * ctx, - const float * embd, - int n_tokens, - int n_past, - int n_threads) { + struct llama_context * ctx, + const float * embd, + int n_tokens, + int n_past, + int n_threads) { if (!llama_eval_internal(*ctx, nullptr, embd, n_tokens, n_past, n_threads, nullptr)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; @@ -3393,10 +3427,10 @@ int llama_eval_export(struct llama_context * ctx, const char * fname) { int llama_tokenize( struct llama_context * ctx, - const char * text, - llama_token * tokens, - int n_max_tokens, - bool add_bos) { + const char * text, + llama_token * tokens, + int n_max_tokens, + bool add_bos) { auto res = llama_tokenize(ctx->vocab, text, add_bos); if (n_max_tokens < (int) res.size()) { @@ -3464,13 +3498,25 @@ llama_token llama_token_nl() { return 13; } +struct llama_timings llama_get_timings(struct llama_context * ctx) { + struct llama_timings result = { + /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, + /*.t_end_ms =*/ 1.00 * ggml_time_ms(), + /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, + /*.t_sample_ms =*/ 1e-3 * ctx->t_sample_us, + /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, + /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, -void llama_print_timings(struct llama_context * ctx) { - const int64_t t_end_us = ggml_time_us(); + /*.n_sample =*/ std::max(1, ctx->n_sample), + /*.n_p_eval =*/ std::max(1, ctx->n_p_eval), + /*.n_eval =*/ std::max(1, ctx->n_eval), + }; + + return result; +} - const int32_t n_sample = std::max(1, ctx->n_sample); - const int32_t n_eval = std::max(1, ctx->n_eval); - const int32_t n_p_eval = std::max(1, ctx->n_p_eval); +void llama_print_timings(struct llama_context * ctx) { + const llama_timings timings = llama_get_timings(ctx); fprintf(stderr, "\n"); fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, timings.t_load_ms); @@ -3479,8 +3525,8 @@ void llama_print_timings(struct llama_context * ctx) { fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval); fprintf(stderr, "%s: eval time = %8.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, 1e-3 * ctx->t_eval_us, n_eval, 1e-3 * ctx->t_eval_us / n_eval, 1e6 / ctx->t_eval_us * n_eval); - fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0); + __func__, timings.t_eval_ms, timings.n_eval, timings.t_eval_ms / timings.n_eval, 1e3 / timings.t_eval_ms * timings.n_eval); + fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (timings.t_end_ms - timings.t_start_ms)); } void llama_reset_timings(struct llama_context * ctx) { @@ -3515,4 +3561,4 @@ const char * llama_print_system_info(void) { // For internal test use const std::vector>& llama_internal_get_tensor_map(struct llama_context * ctx) { return ctx->model.tensors_by_name; -} \ No newline at end of file +} From 172fffeedcebf9d3a654b9be1be8d14446d2f1f6 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 11 Jul 2023 21:41:15 +0800 Subject: [PATCH 011/180] updated llama.cpp build process to build OpenBLAS from source --- CMakeLists.txt | 73 ++++++-------------------------------------------- 1 file changed, 8 insertions(+), 65 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ec683c5ba1f4d..98c6dbcbf4fd9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -163,71 +163,14 @@ if (LLAMA_BLAS) set(BLA_SIZEOF_INTEGER 8) endif() - set(BLA_VENDOR ${LLAMA_BLAS_VENDOR}) - find_package(BLAS) - - if (BLAS_FOUND) - message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") - - if ("${BLAS_INCLUDE_DIRS}" STREQUAL "") - # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake. - # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268 - find_package(PkgConfig REQUIRED) - if (${LLAMA_BLAS_VENDOR} MATCHES "Generic") - pkg_check_modules(DepBLAS REQUIRED blas) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "OpenBLAS") - pkg_check_modules(DepBLAS REQUIRED openblas) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "FLAME") - pkg_check_modules(DepBLAS REQUIRED blis) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "ATLAS") - pkg_check_modules(DepBLAS REQUIRED blas-atlas) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "FlexiBLAS") - pkg_check_modules(DepBLAS REQUIRED flexiblas_api) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "Intel") - # all Intel* libraries share the same include path - pkg_check_modules(DepBLAS REQUIRED mkl-sdl) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "NVHPC") - # this doesn't provide pkg-config - # suggest to assign BLAS_INCLUDE_DIRS on your own - if ("${NVHPC_VERSION}" STREQUAL "") - message(WARNING "Better to set NVHPC_VERSION") - else() - set(DepBLAS_FOUND ON) - set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include") - endif() - endif() - if (DepBLAS_FOUND) - set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS}) - else() - message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically" - " detected by pkgconfig, trying to find cblas.h from possible paths...") - find_path(BLAS_INCLUDE_DIRS - NAMES cblas.h - HINTS - /usr/include - /usr/local/include - /usr/include/openblas - /opt/homebrew/opt/openblas/include - /usr/local/opt/openblas/include - /usr/include/x86_64-linux-gnu/openblas/include - ) - endif() - endif() - - message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}") - add_compile_options(${BLAS_LINKER_FLAGS}) - add_compile_definitions(GGML_USE_OPENBLAS) - if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${LLAMA_BLAS_VENDOR} MATCHES "Generic" OR ${LLAMA_BLAS_VENDOR} MATCHES "Intel")) - add_compile_definitions(GGML_BLAS_USE_MKL) - endif() - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES}) - set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS}) - - else() - message(WARNING "BLAS not found, please refer to " - "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" - " to set correct LLAMA_BLAS_VENDOR") - endif() + add_subdirectory(../OpenBLAS ${CMAKE_CURRENT_BINARY_DIR}/openblas) + set(BLAS_INCLUDE_DIRS ../OpenBLAS ${CMAKE_CURRENT_BINARY_DIR}/openblas) + + message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}") + add_compile_options(${BLAS_LINKER_FLAGS}) + add_compile_definitions(GGML_USE_OPENBLAS) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} openblas) + set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS}) endif() if (LLAMA_K_QUANTS) From 6657588b20c6340d99532c05feb9f93ae8c306ba Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Mon, 17 Jul 2023 12:08:06 +0800 Subject: [PATCH 012/180] updated llama.cpp read file to print unexpected end of file error condition --- llama-util.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama-util.h b/llama-util.h index 042ebe43c48e1..6cd1814522527 100644 --- a/llama-util.h +++ b/llama-util.h @@ -111,7 +111,7 @@ struct llama_file { throw std::runtime_error(format("read error: %s", strerror(errno))); } if (ret != 1) { - throw std::runtime_error(std::string("unexpectedly reached end of file")); + throw std::runtime_error(std::string(format("unexpectedly reached end of file; ret = %zu", ret))); } } From 02b3ecef7502dd8330107da08e61ef261f06a059 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Mon, 31 Jul 2023 12:52:38 +0800 Subject: [PATCH 013/180] added more grammar examples --- grammars/chinese.gbnf | 4 ++++ grammars/korean.gbnf | 8 ++++++++ grammars/schedule.gbnf | 3 +++ 3 files changed, 15 insertions(+) create mode 100644 grammars/chinese.gbnf create mode 100644 grammars/korean.gbnf create mode 100644 grammars/schedule.gbnf diff --git a/grammars/chinese.gbnf b/grammars/chinese.gbnf new file mode 100644 index 0000000000000..e9653f02c3d5f --- /dev/null +++ b/grammars/chinese.gbnf @@ -0,0 +1,4 @@ +root ::= cn-char+ ([ \t\n] cn-char+)* +cn-char ::= cjk | punctuation +cjk ::= [一-鿿] | [𠀀-𯿽] +punctuation ::= [、-〾] diff --git a/grammars/korean.gbnf b/grammars/korean.gbnf new file mode 100644 index 0000000000000..87e9a439b6dad --- /dev/null +++ b/grammars/korean.gbnf @@ -0,0 +1,8 @@ +root ::= conversation+ +conversation ::= assistant-line "\nUSER: " +assistant-line ::= kr-string "\n" +kr-string ::= kr-char* +kr-char ::= hangul | punctuation | whitespace +hangul ::= [가-힣] +punctuation ::= [、-〾] +whitespace ::= [ \t] \ No newline at end of file diff --git a/grammars/schedule.gbnf b/grammars/schedule.gbnf new file mode 100644 index 0000000000000..6784f98dabdff --- /dev/null +++ b/grammars/schedule.gbnf @@ -0,0 +1,3 @@ +root ::= record +record ::= "Event: " string "\n" "Date: " string "\n" "Time: " string +string ::= "" [ -~]* "" \ No newline at end of file From 51c6e5b9180d06654a1fa7910ca9077dda079385 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 1 Aug 2023 13:51:13 +0800 Subject: [PATCH 014/180] updated schedule grammar --- grammars/schedule.gbnf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grammars/schedule.gbnf b/grammars/schedule.gbnf index 6784f98dabdff..4a04c771a1851 100644 --- a/grammars/schedule.gbnf +++ b/grammars/schedule.gbnf @@ -1,3 +1,3 @@ root ::= record -record ::= "Event: " string "\n" "Date: " string "\n" "Time: " string +record ::= "Event: " string "\n" "Date: " string "\n" "Time: " string "\n" string ::= "" [ -~]* "" \ No newline at end of file From 2ece129e35946b2bab5be6ba0e1f19fce2a1b195 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Wed, 2 Aug 2023 16:33:26 +0800 Subject: [PATCH 015/180] added feature to stream saving to file instead of allocating --- llama.cpp | 117 ++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 113 insertions(+), 4 deletions(-) diff --git a/llama.cpp b/llama.cpp index 503bcca7f0895..d67bf3de0cae9 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3832,6 +3832,105 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { return written; } + +// writes state data directly to file instead of copying it into a buffer +void llama_write_state_data_to_file(struct llama_context * ctx, llama_file * dst_file) { + // copy rng + { + std::stringstream rng_ss; + rng_ss << ctx->rng; + + const size_t rng_size = rng_ss.str().size(); + char rng_buf[LLAMA_MAX_RNG_STATE]; + + memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE); + memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); + + dst_file->write_raw(&rng_size, sizeof(rng_size)); + dst_file->write_raw(&rng_buf[0], LLAMA_MAX_RNG_STATE); + } + + // copy logits + { + const size_t logits_cap = ctx->logits.capacity(); + const size_t logits_size = ctx->logits.size(); + + dst_file->write_raw(&logits_cap, sizeof(logits_cap)); + dst_file->write_raw(&logits_size, sizeof(logits_size)); + + if (logits_size) { + dst_file->write_raw(ctx->logits.data(), logits_size * sizeof(float)); + } + + // If there is a gap between the size and the capacity, write padding + size_t padding_size = (logits_cap - logits_size) * sizeof(float); + if (padding_size > 0) { + std::vector padding(padding_size, 0); // Create a buffer filled with zeros + dst_file->write_raw(padding.data(), padding_size); + } + } + + // copy embeddings + { + const size_t embedding_size = ctx->embedding.size(); + + dst_file->write_raw(&embedding_size, sizeof(embedding_size)); + + if (embedding_size) { + dst_file->write_raw(ctx->embedding.data(), embedding_size * sizeof(float)); + } + } + + // copy kv cache + { + const auto & kv_self = ctx->kv_self; + const auto & hparams = ctx->model.hparams; + const int n_layer = hparams.n_layer; + const int n_embd = hparams.n_embd_gqa(); + const int n_ctx = hparams.n_ctx; + + const size_t kv_size = kv_self.buf.size; + const int kv_ntok = llama_get_kv_cache_token_count(ctx); + + dst_file->write_raw(&kv_size, sizeof(kv_size)); + dst_file->write_raw(&kv_ntok, sizeof(kv_ntok)); + + if (kv_size) { + const size_t elt_size = ggml_element_size(kv_self.k); + + ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); + ggml_cgraph gf{}; + + ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); + std::vector kout3d_data(ggml_nbytes(kout3d), 0); + kout3d->data = kout3d_data.data(); + + ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer); + std::vector vout3d_data(ggml_nbytes(vout3d), 0); + vout3d->data = vout3d_data.data(); + + ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, + n_embd, kv_ntok, n_layer, + elt_size*n_embd, elt_size*n_embd*n_ctx, 0); + + ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v, + kv_ntok, n_embd, n_layer, + elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); + + ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); + ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d)); + ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1); + + ggml_free(cpy_ctx); + + // our data is now in the kout3d_data and vout3d_data buffers + // write them to file + dst_file->write_raw(kout3d_data.data(), kout3d_data.size()); + dst_file->write_raw(vout3d_data.data(), vout3d_data.size()); + } + } +} + // Sets the state reading from the specified source address size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { uint8_t * inp = src; @@ -4016,12 +4115,22 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi // save the context state { - const size_t n_state_size_max = llama_get_state_size(ctx); + // const size_t n_state_size_max = llama_get_state_size(ctx); - std::vector state_data(n_state_size_max); - const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data()); + // fprintf(stderr, "%s: state size: %zu\n", __func__, n_state_size_max); + + // std::vector state_data(n_state_size_max); + + // fprintf(stderr, "%s: allocated mem for state data\n", __func__); + + // const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data()); + + // fprintf(stderr, "%s: state size (actual): %zu\n", __func__, n_state_size_cur); + + // file.write_raw(state_data.data(), n_state_size_cur); - file.write_raw(state_data.data(), n_state_size_cur); + // use streaming write instead of copying to a buffer + llama_write_state_data_to_file(ctx, &file); } return true; From 6c798db0413d8701e9986cfde1c76ec888f434a9 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Wed, 2 Aug 2023 16:41:25 +0800 Subject: [PATCH 016/180] added stream saving context data to file to avoid allocating unnecessary amounts of memory --- llama.cpp | 109 +++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 100 insertions(+), 9 deletions(-) diff --git a/llama.cpp b/llama.cpp index d427054dd8d6a..543d85e2c676c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3841,6 +3841,104 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { return written; } +// writes state data directly to file instead of copying it into a buffer +void llama_write_state_data_to_file(struct llama_context * ctx, llama_file * dst_file) { + // copy rng + { + std::stringstream rng_ss; + rng_ss << ctx->rng; + + const size_t rng_size = rng_ss.str().size(); + char rng_buf[LLAMA_MAX_RNG_STATE]; + + memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE); + memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); + + dst_file->write_raw(&rng_size, sizeof(rng_size)); + dst_file->write_raw(&rng_buf[0], LLAMA_MAX_RNG_STATE); + } + + // copy logits + { + const size_t logits_cap = ctx->logits.capacity(); + const size_t logits_size = ctx->logits.size(); + + dst_file->write_raw(&logits_cap, sizeof(logits_cap)); + dst_file->write_raw(&logits_size, sizeof(logits_size)); + + if (logits_size) { + dst_file->write_raw(ctx->logits.data(), logits_size * sizeof(float)); + } + + // If there is a gap between the size and the capacity, write padding + size_t padding_size = (logits_cap - logits_size) * sizeof(float); + if (padding_size > 0) { + std::vector padding(padding_size, 0); // Create a buffer filled with zeros + dst_file->write_raw(padding.data(), padding_size); + } + } + + // copy embeddings + { + const size_t embedding_size = ctx->embedding.size(); + + dst_file->write_raw(&embedding_size, sizeof(embedding_size)); + + if (embedding_size) { + dst_file->write_raw(ctx->embedding.data(), embedding_size * sizeof(float)); + } + } + + // copy kv cache + { + const auto & kv_self = ctx->kv_self; + const auto & hparams = ctx->model.hparams; + const int n_layer = hparams.n_layer; + const int n_embd = hparams.n_embd_gqa(); + const int n_ctx = hparams.n_ctx; + + const size_t kv_size = kv_self.buf.size; + const int kv_ntok = llama_get_kv_cache_token_count(ctx); + + dst_file->write_raw(&kv_size, sizeof(kv_size)); + dst_file->write_raw(&kv_ntok, sizeof(kv_ntok)); + + if (kv_size) { + const size_t elt_size = ggml_element_size(kv_self.k); + + ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); + ggml_cgraph gf{}; + + ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); + std::vector kout3d_data(ggml_nbytes(kout3d), 0); + kout3d->data = kout3d_data.data(); + + ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer); + std::vector vout3d_data(ggml_nbytes(vout3d), 0); + vout3d->data = vout3d_data.data(); + + ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, + n_embd, kv_ntok, n_layer, + elt_size*n_embd, elt_size*n_embd*n_ctx, 0); + + ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v, + kv_ntok, n_embd, n_layer, + elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); + + ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); + ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d)); + ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1); + + ggml_free(cpy_ctx); + + // our data is now in the kout3d_data and vout3d_data buffers + // write them to file + dst_file->write_raw(kout3d_data.data(), kout3d_data.size()); + dst_file->write_raw(vout3d_data.data(), vout3d_data.size()); + } + } +} + // Sets the state reading from the specified source address size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { uint8_t * inp = src; @@ -4023,15 +4121,8 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi file.write_u32((uint32_t) n_token_count); file.write_raw(tokens, sizeof(llama_token) * n_token_count); - // save the context state - { - const size_t n_state_size_max = llama_get_state_size(ctx); - - std::vector state_data(n_state_size_max); - const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data()); - - file.write_raw(state_data.data(), n_state_size_cur); - } + // save the context state using stream saving + llama_write_state_data_to_file(ctx, &file); return true; } From 2c12d6d0678ac3b11ed370e0b234833ddbb00974 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Wed, 2 Aug 2023 17:18:11 +0800 Subject: [PATCH 017/180] removed commented out code --- llama.cpp | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/llama.cpp b/llama.cpp index d67bf3de0cae9..acc30a701e608 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4113,25 +4113,8 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi file.write_u32((uint32_t) n_token_count); file.write_raw(tokens, sizeof(llama_token) * n_token_count); - // save the context state - { - // const size_t n_state_size_max = llama_get_state_size(ctx); - - // fprintf(stderr, "%s: state size: %zu\n", __func__, n_state_size_max); - - // std::vector state_data(n_state_size_max); - - // fprintf(stderr, "%s: allocated mem for state data\n", __func__); - - // const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data()); - - // fprintf(stderr, "%s: state size (actual): %zu\n", __func__, n_state_size_cur); - - // file.write_raw(state_data.data(), n_state_size_cur); - - // use streaming write instead of copying to a buffer - llama_write_state_data_to_file(ctx, &file); - } + // use streaming write instead of copying to a buffer + llama_write_state_data_to_file(ctx, &file); return true; } @@ -4362,4 +4345,4 @@ const char * llama_print_system_info(void) { // For internal test use const std::vector>& llama_internal_get_tensor_map(struct llama_context * ctx) { return ctx->model.tensors_by_name; -} \ No newline at end of file +} From 3cebd6e4b70c7b976e1c2e7982972d9b2ee837ff Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Thu, 3 Aug 2023 17:02:29 +0800 Subject: [PATCH 018/180] generalised copying state data to file or buffer --- llama-util.h | 23 +++++++++ llama.cpp | 129 +++++++-------------------------------------------- 2 files changed, 39 insertions(+), 113 deletions(-) diff --git a/llama-util.h b/llama-util.h index 042ebe43c48e1..da1cdf8d5cf33 100644 --- a/llama-util.h +++ b/llama-util.h @@ -149,6 +149,29 @@ struct llama_file { } }; +// llama_context_data +struct llama_data_context { + virtual void write(const void* src, size_t size) = 0; + virtual ~llama_data_context() = default; +}; + +struct llama_data_buffer_context : llama_data_context { + uint8_t* ptr; + llama_data_buffer_context(uint8_t* p) : ptr(p) {} + void write(const void* src, size_t size) override { + memcpy(ptr, src, size); + ptr += size; + } +}; + +struct llama_data_file_context : llama_data_context { + llama_file* file; + llama_data_file_context(llama_file* f) : file(f) {} + void write(const void* src, size_t size) override { + file->write_raw(src, size); + } +}; + #if defined(_WIN32) static std::string llama_format_win_err(DWORD err) { LPSTR buf; diff --git a/llama.cpp b/llama.cpp index 543d85e2c676c..83baa941b808f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3743,106 +3743,8 @@ size_t llama_get_state_size(const struct llama_context * ctx) { return s_total; } -// Copies the state to the specified destination address -size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { - uint8_t * out = dst; - - // copy rng - { - std::stringstream rng_ss; - rng_ss << ctx->rng; - - const size_t rng_size = rng_ss.str().size(); - char rng_buf[LLAMA_MAX_RNG_STATE]; - - memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE); - memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); - - memcpy(out, &rng_size, sizeof(rng_size)); out += sizeof(rng_size); - memcpy(out, &rng_buf[0], LLAMA_MAX_RNG_STATE); out += LLAMA_MAX_RNG_STATE; - } - - // copy logits - { - const size_t logits_cap = ctx->logits.capacity(); - const size_t logits_size = ctx->logits.size(); - - memcpy(out, &logits_cap, sizeof(logits_cap)); out += sizeof(logits_cap); - memcpy(out, &logits_size, sizeof(logits_size)); out += sizeof(logits_size); - - if (logits_size) { - memcpy(out, ctx->logits.data(), logits_size * sizeof(float)); - } - - out += logits_cap * sizeof(float); - } - - // copy embeddings - { - const size_t embedding_size = ctx->embedding.size(); - - memcpy(out, &embedding_size, sizeof(embedding_size)); out += sizeof(embedding_size); - - if (embedding_size) { - memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); - out += embedding_size * sizeof(float); - } - } - - // copy kv cache - { - const auto & kv_self = ctx->kv_self; - const auto & hparams = ctx->model.hparams; - const int n_layer = hparams.n_layer; - const int n_embd = hparams.n_embd_gqa(); - const int n_ctx = hparams.n_ctx; - - const size_t kv_size = kv_self.buf.size; - const int kv_ntok = llama_get_kv_cache_token_count(ctx); - - memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size); - memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok); - - if (kv_size) { - const size_t elt_size = ggml_element_size(kv_self.k); - - ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); - ggml_cgraph gf{}; - - ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); - kout3d->data = out; - out += ggml_nbytes(kout3d); - - ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer); - vout3d->data = out; - out += ggml_nbytes(vout3d); - - ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, - n_embd, kv_ntok, n_layer, - elt_size*n_embd, elt_size*n_embd*n_ctx, 0); - - ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v, - kv_ntok, n_embd, n_layer, - elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); - - ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); - ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d)); - ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1); - - ggml_free(cpy_ctx); - } - } - - const size_t written = out - dst; - const size_t max_size = llama_get_state_size(ctx); - - LLAMA_ASSERT(written <= max_size); - - return written; -} - -// writes state data directly to file instead of copying it into a buffer -void llama_write_state_data_to_file(struct llama_context * ctx, llama_file * dst_file) { +// copy state data into either a buffer or file depending on the passed in context +void llama_copy_state_data(struct llama_context * ctx, llama_data_context * data_ctx) { // copy rng { std::stringstream rng_ss; @@ -3854,8 +3756,8 @@ void llama_write_state_data_to_file(struct llama_context * ctx, llama_file * dst memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE); memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); - dst_file->write_raw(&rng_size, sizeof(rng_size)); - dst_file->write_raw(&rng_buf[0], LLAMA_MAX_RNG_STATE); + data_ctx->write(&rng_size, sizeof(rng_size)); + data_ctx->write(&rng_buf[0], LLAMA_MAX_RNG_STATE); } // copy logits @@ -3863,18 +3765,18 @@ void llama_write_state_data_to_file(struct llama_context * ctx, llama_file * dst const size_t logits_cap = ctx->logits.capacity(); const size_t logits_size = ctx->logits.size(); - dst_file->write_raw(&logits_cap, sizeof(logits_cap)); - dst_file->write_raw(&logits_size, sizeof(logits_size)); + data_ctx->write(&logits_cap, sizeof(logits_cap)); + data_ctx->write(&logits_size, sizeof(logits_size)); if (logits_size) { - dst_file->write_raw(ctx->logits.data(), logits_size * sizeof(float)); + data_ctx->write(ctx->logits.data(), logits_size * sizeof(float)); } // If there is a gap between the size and the capacity, write padding size_t padding_size = (logits_cap - logits_size) * sizeof(float); if (padding_size > 0) { std::vector padding(padding_size, 0); // Create a buffer filled with zeros - dst_file->write_raw(padding.data(), padding_size); + data_ctx->write(padding.data(), padding_size); } } @@ -3882,10 +3784,10 @@ void llama_write_state_data_to_file(struct llama_context * ctx, llama_file * dst { const size_t embedding_size = ctx->embedding.size(); - dst_file->write_raw(&embedding_size, sizeof(embedding_size)); + data_ctx->write(&embedding_size, sizeof(embedding_size)); if (embedding_size) { - dst_file->write_raw(ctx->embedding.data(), embedding_size * sizeof(float)); + data_ctx->write(ctx->embedding.data(), embedding_size * sizeof(float)); } } @@ -3900,8 +3802,8 @@ void llama_write_state_data_to_file(struct llama_context * ctx, llama_file * dst const size_t kv_size = kv_self.buf.size; const int kv_ntok = llama_get_kv_cache_token_count(ctx); - dst_file->write_raw(&kv_size, sizeof(kv_size)); - dst_file->write_raw(&kv_ntok, sizeof(kv_ntok)); + data_ctx->write(&kv_size, sizeof(kv_size)); + data_ctx->write(&kv_ntok, sizeof(kv_ntok)); if (kv_size) { const size_t elt_size = ggml_element_size(kv_self.k); @@ -3933,8 +3835,8 @@ void llama_write_state_data_to_file(struct llama_context * ctx, llama_file * dst // our data is now in the kout3d_data and vout3d_data buffers // write them to file - dst_file->write_raw(kout3d_data.data(), kout3d_data.size()); - dst_file->write_raw(vout3d_data.data(), vout3d_data.size()); + data_ctx->write(kout3d_data.data(), kout3d_data.size()); + data_ctx->write(vout3d_data.data(), vout3d_data.size()); } } } @@ -4122,7 +4024,8 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi file.write_raw(tokens, sizeof(llama_token) * n_token_count); // save the context state using stream saving - llama_write_state_data_to_file(ctx, &file); + llama_data_file_context data_ctx(&file); + llama_copy_state_data(ctx, &data_ctx); return true; } From 186c9a0c411de4c825050923d9f9e1f3605c600f Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 8 Aug 2023 19:58:57 +0800 Subject: [PATCH 019/180] ignored lora folder --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index c1ab6bb6d08a3..3f6b70f40f21f 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ build-sanitize-thread/ out/ tmp/ +loras/* models/* models-mnt From 726d68081fa081c75ace02bea4e48f350a4dd6fb Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Fri, 18 Aug 2023 12:22:56 +0800 Subject: [PATCH 020/180] edit main to show token id --- examples/main/main.cpp | 1 + prompts/chat-with-layla.txt | 5 +++++ 2 files changed, 6 insertions(+) create mode 100644 prompts/chat-with-layla.txt diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 56ada7e69d99d..b0fac867f6319 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -659,6 +659,7 @@ int main(int argc, char ** argv) { if (input_echo) { for (auto id : embd) { printf("%s", llama_token_to_str(ctx, id)); + printf("(%d)", id); } fflush(stdout); } diff --git a/prompts/chat-with-layla.txt b/prompts/chat-with-layla.txt new file mode 100644 index 0000000000000..edcf9ae2542c2 --- /dev/null +++ b/prompts/chat-with-layla.txt @@ -0,0 +1,5 @@ +Layla is an AI created by Layla Network that is helpful, polite, and to the point. She is here to help User with everyday tasks. Layla's favourite animal is the butterfly because it represents transformation, growth, and beauty. + +Conversation between User and an Assistant named "Layla": + +Korean User: \ No newline at end of file From 3c6c65b89b9af696f9f8c48b41f15230dbd3b016 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Wed, 23 Aug 2023 23:40:32 +0800 Subject: [PATCH 021/180] updated chat with layla prompt template --- prompts/chat-with-layla.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prompts/chat-with-layla.txt b/prompts/chat-with-layla.txt index edcf9ae2542c2..fc500ed8b4414 100644 --- a/prompts/chat-with-layla.txt +++ b/prompts/chat-with-layla.txt @@ -2,4 +2,4 @@ Layla is an AI created by Layla Network that is helpful, polite, and to the poin Conversation between User and an Assistant named "Layla": -Korean User: \ No newline at end of file +User: \ No newline at end of file From f15f3e308e86820c2c7f86ec6cdbc94296062573 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Thu, 24 Aug 2023 00:11:23 +0800 Subject: [PATCH 022/180] added support for eot token and unknown token (for an improperly tokenised layla fine-tune) --- examples/main/main.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index c63abaab45827..6d6907dd4905b 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -54,6 +54,15 @@ void sigint_handler(int signo) { } #endif + +llama_token llama_token_unk() { + return 0; +} + +llama_token llama_token_eot() { + return 32000; +} + int main(int argc, char ** argv) { gpt_params params; @@ -705,7 +714,7 @@ int main(int argc, char ** argv) { } // deal with end of text token in interactive mode - if (last_n_tokens.back() == llama_token_eos()) { + if (last_n_tokens.back() == llama_token_eos() || last_n_tokens.back() == llama_token_eot() || last_n_tokens.back() == llama_token_unk()) { if (params.interactive) { if (params.antiprompt.size() != 0) { // tokenize and inject first reverse prompt From f4f5790711952e4caac270955d8cb8318fd4f90a Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 3 Oct 2023 03:17:44 +0800 Subject: [PATCH 023/180] fixed small bugs in llama.cpp --- llama.cpp | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/llama.cpp b/llama.cpp index bff17135b985f..9bbd2c9d1d09a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7058,14 +7058,14 @@ struct llama_data_file_context : llama_data_context { */ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { // TODO: does not support multi-sequence states - { - const auto & kv_self = ctx->kv_self; - for (uint32_t i = 0; i < kv_self.head; ++i) { - GGML_ASSERT(kv_self.cells[i].pos == (int32_t) i); - GGML_ASSERT(kv_self.cells[i].seq_id.size() == 1); - GGML_ASSERT(kv_self.cells[i].has_seq_id(0)); - } - } + // { + // const auto & kv_self = ctx->kv_self; + // for (uint32_t i = 0; i < kv_self.head; ++i) { + // GGML_ASSERT(kv_self.cells[i].pos == (int32_t) i); + // GGML_ASSERT(kv_self.cells[i].seq_id.size() == 1); + // GGML_ASSERT(kv_self.cells[i].has_seq_id(0)); + // } + // } // copy rng { @@ -7271,7 +7271,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { } ctx->kv_self.head = kv_ntok; - ctx->kv_self.size = kv_size; + ctx->kv_self.size = n_ctx; } const size_t nread = inp - src; @@ -7298,10 +7298,11 @@ static bool llama_load_session_file_internal(struct llama_context * ctx, const c llama_hparams session_hparams; file.read_raw(&session_hparams, sizeof(llama_hparams)); - if (session_hparams != ctx->model.hparams) { - LLAMA_LOG_INFO("%s : model hparams didn't match from session file!\n", __func__); - return false; - } + // TODO: need to do floating point comparison imprecisely for norm_eps + //if (session_hparams != ctx->model.hparams) { + // LLAMA_LOG_INFO("%s : model hparams didn't match from session file!\n", __func__); + // return false; + //} } // load the prompt From 0cea5ed583ecf8f23e604fb266b766ee680a8fce Mon Sep 17 00:00:00 2001 From: George Date: Sat, 7 Oct 2023 21:07:47 +0800 Subject: [PATCH 024/180] fixes various bad memory access errors in iOS 17 --- ggml-alloc.c | 54 ++++++++++++++++++++++++---------------------------- ggml.c | 53 ++++++++++++++++++++++++++------------------------- llama.cpp | 26 +++++++++++++++++++++---- 3 files changed, 74 insertions(+), 59 deletions(-) diff --git a/ggml-alloc.c b/ggml-alloc.c index 805759db74fef..66e1eee5b9357 100644 --- a/ggml-alloc.c +++ b/ggml-alloc.c @@ -291,21 +291,17 @@ void ggml_allocr_reset(struct ggml_allocr * alloc) { struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) { struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */); - *alloc = (struct ggml_allocr){ - /*.data = */ data, - /*.size = */ size, - /*.alignment = */ alignment, - /*.n_free_blocks = */ 0, - /*.free_blocks = */ {{0}}, - /*.hash_table = */ {{0}}, - /*.max_size = */ 0, - /*.measure = */ false, - /*.parse_seq = */ {0}, - /*.parse_seq_len = */ 0, -#ifdef GGML_ALLOCATOR_DEBUG - /*.allocated_tensors = */ {0}, -#endif - }; + (*alloc).data = data; + (*alloc).size = size; + (*alloc).alignment = alignment; + (*alloc).n_free_blocks = 0; + (*alloc).max_size = 0; + (*alloc).measure = false; + (*alloc).parse_seq_len = 0; + + memset((*alloc).free_blocks, 0, sizeof((*alloc).free_blocks)); + memset((*alloc).hash_table, 0, sizeof((*alloc).hash_table)); + memset((*alloc).parse_seq, 0, sizeof((*alloc).parse_seq)); ggml_allocr_reset(alloc); @@ -370,22 +366,22 @@ struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) { size_t size; alloc_measure_vmem(&base_addr, &size); - - *alloc = (struct ggml_allocr){ - /*.data = */ base_addr, - /*.size = */ size, - /*.alignment = */ alignment, - /*.n_free_blocks = */ 0, - /*.free_blocks = */ {{0}}, - /*.hash_table = */ {{0}}, - /*.max_size = */ 0, - /*.measure = */ true, - /*.parse_seq = */ {0}, - /*.parse_seq_len = */ 0, + + (*alloc).data = base_addr; + (*alloc).size = size; + (*alloc).alignment = alignment; + (*alloc).n_free_blocks = 0; + (*alloc).max_size = 0; + (*alloc).measure = true; + (*alloc).parse_seq_len = 0; + + memset((*alloc).free_blocks, 0, sizeof((*alloc).free_blocks)); + memset((*alloc).hash_table, 0, sizeof((*alloc).hash_table)); + memset((*alloc).parse_seq, 0, sizeof((*alloc).parse_seq)); + #ifdef GGML_ALLOCATOR_DEBUG - /*.allocated_tensors = */ {0}, + (*alloc).allocated_tensors = {0}; #endif - }; ggml_allocr_reset(alloc); diff --git a/ggml.c b/ggml.c index 911a63988e027..a288895ff6d08 100644 --- a/ggml.c +++ b/ggml.c @@ -4722,19 +4722,21 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { } const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN); - - *ctx = (struct ggml_context) { - /*.mem_size =*/ mem_size, - /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size), - /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, - /*.no_alloc =*/ params.no_alloc, - /*.no_alloc_save =*/ params.no_alloc, - /*.n_objects =*/ 0, - /*.objects_begin =*/ NULL, - /*.objects_end =*/ NULL, - /*.scratch =*/ { 0, 0, NULL, }, - /*.scratch_save =*/ { 0, 0, NULL, }, - }; + + ctx = (struct ggml_context *)malloc(sizeof(struct ggml_context)); + + struct ggml_scratch empty_scratch = { 0, 0, NULL }; + + (*ctx).mem_size = mem_size; + (*ctx).mem_buffer = params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size); + (*ctx).mem_buffer_owned = params.mem_buffer ? false : true; + (*ctx).no_alloc = params.no_alloc; + (*ctx).no_alloc_save = params.no_alloc; + (*ctx).n_objects = 0; + (*ctx).objects_begin = NULL; + (*ctx).objects_end = NULL; + (*ctx).scratch = empty_scratch; + (*ctx).scratch_save = empty_scratch; GGML_ASSERT(ctx->mem_buffer != NULL); @@ -18078,19 +18080,18 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) { struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, GGML_GRAPH_SIZE); struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs); - - *cgraph = (struct ggml_cgraph) { - /*.n_nodes =*/ 0, - /*.n_leafs =*/ 0, - /*.nodes =*/ { NULL }, - /*.grads =*/ { NULL }, - /*.leafs =*/ { NULL }, - /*.hash_table =*/ { NULL }, - /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, - /*.perf_runs =*/ 0, - /*.perf_cycles =*/ 0, - /*.perf_time_us =*/ 0, - }; + + (*cgraph).n_nodes = 0; + (*cgraph).n_leafs = 0; + (*cgraph).order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT; + (*cgraph).perf_runs = 0; + (*cgraph).perf_cycles = 0; + (*cgraph).perf_time_us = 0; + + memset((*cgraph).nodes, 0, sizeof((*cgraph).nodes)); + memset((*cgraph).grads, 0, sizeof((*cgraph).grads)); + memset((*cgraph).leafs, 0, sizeof((*cgraph).leafs)); + memset((*cgraph).visited_hash_table, 0, sizeof((*cgraph).visited_hash_table)); return cgraph; } diff --git a/llama.cpp b/llama.cpp index d828922b10d9d..b8804b9bf4217 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7666,7 +7666,21 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat const size_t elt_size = ggml_element_size(kv_self.k); ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); - ggml_cgraph gf{}; + + // create a temporary cgraph without initialising ggml objects, code inspired from `ggml.c:ggml_new_graph` + struct ggml_cgraph * gf = (struct ggml_cgraph *) (malloc(sizeof(ggml_cgraph))); + + (*gf).n_nodes = 0; + (*gf).n_leafs = 0; + (*gf).order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT; + (*gf).perf_runs = 0; + (*gf).perf_cycles = 0; + (*gf).perf_time_us = 0; + + memset((*gf).nodes, 0, sizeof((*gf).nodes)); + memset((*gf).grads, 0, sizeof((*gf).grads)); + memset((*gf).leafs, 0, sizeof((*gf).leafs)); + memset((*gf).visited_hash_table, 0, sizeof((*gf).visited_hash_table)); ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer); std::vector kout3d_data(ggml_nbytes(kout3d), 0); @@ -7684,9 +7698,9 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat kv_head, n_embd, n_layer, elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); - ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); - ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d)); - ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1); + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k3d, kout3d)); + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v3d, vout3d)); + ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1); ggml_free(cpy_ctx); @@ -7694,6 +7708,10 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat // write them to file data_ctx->write(kout3d_data.data(), kout3d_data.size()); data_ctx->write(vout3d_data.data(), vout3d_data.size()); + + // free our allocated graph + free(gf); + gf = NULL; } for (uint32_t i = 0; i < kv_size; ++i) { From 256991561d4918b8a142a8bd33e44504e9d4b3c8 Mon Sep 17 00:00:00 2001 From: George Date: Sat, 7 Oct 2023 21:35:52 +0800 Subject: [PATCH 025/180] fixed load session --- llama.cpp | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/llama.cpp b/llama.cpp index b8804b9bf4217..b2f469454e806 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7812,7 +7812,21 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { const size_t elt_size = ggml_element_size(kv_self.k); ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); - ggml_cgraph gf{}; + + // create a temporary cgraph without initialising ggml objects, code inspired from `ggml.c:ggml_new_graph` + struct ggml_cgraph * gf = (struct ggml_cgraph *) (malloc(sizeof(ggml_cgraph))); + + (*gf).n_nodes = 0; + (*gf).n_leafs = 0; + (*gf).order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT; + (*gf).perf_runs = 0; + (*gf).perf_cycles = 0; + (*gf).perf_time_us = 0; + + memset((*gf).nodes, 0, sizeof((*gf).nodes)); + memset((*gf).grads, 0, sizeof((*gf).grads)); + memset((*gf).leafs, 0, sizeof((*gf).leafs)); + memset((*gf).visited_hash_table, 0, sizeof((*gf).visited_hash_table)); ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer); kin3d->data = (void *) inp; @@ -7830,9 +7844,9 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { kv_head, n_embd, n_layer, elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); - ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d)); - ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d)); - ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1); + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin3d, k3d)); + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin3d, v3d)); + ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1); ggml_free(cpy_ctx); } From 5a86739e3d6758852b92f0266e29301ac8274e2e Mon Sep 17 00:00:00 2001 From: George Date: Sat, 7 Oct 2023 22:07:19 +0800 Subject: [PATCH 026/180] do not crash if met with unknown tokens in token to piece function --- llama.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index b2f469454e806..26876dfc8b068 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8166,7 +8166,8 @@ int llama_token_to_piece(const struct llama_model * model, llama_token token, ch buf[0] = llama_token_to_byte(model->vocab, token); return 1; } else { - GGML_ASSERT(false); + // let's not crash the program on unknown tokens + //GGML_ASSERT(false); } break; } @@ -8182,12 +8183,14 @@ int llama_token_to_piece(const struct llama_model * model, llama_token token, ch } else if (llama_is_control_token(model->vocab, token)) { ; } else { - GGML_ASSERT(false); + // let's not crash the program on unknown tokens + //GGML_ASSERT(false); } break; } default: - GGML_ASSERT(false); + // let's not crash the program on unknown tokens + //GGML_ASSERT(false); } } return 0; From f9235f6bcde7119474ff8335f22fa0cc33c3928c Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Sat, 7 Oct 2023 22:57:05 +0800 Subject: [PATCH 027/180] crash on unknown tokenizer --- llama.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/llama.cpp b/llama.cpp index 26876dfc8b068..40e4ac586e7f3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7666,10 +7666,10 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat const size_t elt_size = ggml_element_size(kv_self.k); ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); - + // create a temporary cgraph without initialising ggml objects, code inspired from `ggml.c:ggml_new_graph` struct ggml_cgraph * gf = (struct ggml_cgraph *) (malloc(sizeof(ggml_cgraph))); - + (*gf).n_nodes = 0; (*gf).n_leafs = 0; (*gf).order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT; @@ -7708,7 +7708,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat // write them to file data_ctx->write(kout3d_data.data(), kout3d_data.size()); data_ctx->write(vout3d_data.data(), vout3d_data.size()); - + // free our allocated graph free(gf); gf = NULL; @@ -7812,10 +7812,10 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { const size_t elt_size = ggml_element_size(kv_self.k); ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); - + // create a temporary cgraph without initialising ggml objects, code inspired from `ggml.c:ggml_new_graph` struct ggml_cgraph * gf = (struct ggml_cgraph *) (malloc(sizeof(ggml_cgraph))); - + (*gf).n_nodes = 0; (*gf).n_leafs = 0; (*gf).order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT; @@ -8191,6 +8191,7 @@ int llama_token_to_piece(const struct llama_model * model, llama_token token, ch default: // let's not crash the program on unknown tokens //GGML_ASSERT(false); + break; } } return 0; From 54e001b70ab6be0929493af93bcb64ccb223f31f Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Sun, 8 Oct 2023 00:32:18 +0800 Subject: [PATCH 028/180] fixed logic error when initialising ggml context --- ggml.c | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/ggml.c b/ggml.c index a288895ff6d08..6864bd161c0b8 100644 --- a/ggml.c +++ b/ggml.c @@ -4723,20 +4723,18 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN); - ctx = (struct ggml_context *)malloc(sizeof(struct ggml_context)); - struct ggml_scratch empty_scratch = { 0, 0, NULL }; - (*ctx).mem_size = mem_size; - (*ctx).mem_buffer = params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size); - (*ctx).mem_buffer_owned = params.mem_buffer ? false : true; - (*ctx).no_alloc = params.no_alloc; - (*ctx).no_alloc_save = params.no_alloc; - (*ctx).n_objects = 0; - (*ctx).objects_begin = NULL; - (*ctx).objects_end = NULL; - (*ctx).scratch = empty_scratch; - (*ctx).scratch_save = empty_scratch; + ctx->mem_size = mem_size; + ctx->mem_buffer = params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size); + ctx->mem_buffer_owned = params.mem_buffer ? false : true; + ctx->no_alloc = params.no_alloc; + ctx->no_alloc_save = params.no_alloc; + ctx->n_objects = 0; + ctx->objects_begin = NULL; + ctx->objects_end = NULL; + ctx->scratch = empty_scratch; + ctx->scratch_save = empty_scratch; GGML_ASSERT(ctx->mem_buffer != NULL); From 6b20c7f1b15776cd1e240e55f331a4ded3087a9d Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Sun, 8 Oct 2023 01:14:09 +0800 Subject: [PATCH 029/180] uncommented hparam comparison check after fix --- llama.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/llama.cpp b/llama.cpp index 2e868cfd79c8d..34b99e36e2c91 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8454,11 +8454,10 @@ static bool llama_load_session_file_internal(struct llama_context * ctx, const c llama_hparams session_hparams; file.read_raw(&session_hparams, sizeof(llama_hparams)); - // TODO: need to do floating point comparison imprecisely for norm_eps - //if (session_hparams != ctx->model.hparams) { - // LLAMA_LOG_INFO("%s : model hparams didn't match from session file!\n", __func__); - // return false; - //} + if (session_hparams != ctx->model.hparams) { + LLAMA_LOG_INFO("%s : model hparams didn't match from session file!\n", __func__); + return false; + } } // load the prompt From cba720e792357dc972496fdd2b063f2c761c2332 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Sun, 8 Oct 2023 20:42:57 +0800 Subject: [PATCH 030/180] added ios toolchain file --- ios.toolchain.cmake | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 ios.toolchain.cmake diff --git a/ios.toolchain.cmake b/ios.toolchain.cmake new file mode 100644 index 0000000000000..10f3676dfb5e8 --- /dev/null +++ b/ios.toolchain.cmake @@ -0,0 +1,17 @@ +set(CMAKE_SYSTEM_NAME iOS) + +# specify the cross compiler +set(CMAKE_XCODE_ATTRIBUTE_ONLY_ACTIVE_ARCH NO) + +# specify which architectures to build for +set(CMAKE_OSX_ARCHITECTURES "$(ARCHS_STANDARD)") + +# you can also choose to build for a specific device +# set(CMAKE_OSX_ARCHITECTURES "arm64") +# or for the simulator +# set(CMAKE_OSX_ARCHITECTURES "x86_64") + +set(CMAKE_XCODE_EFFECTIVE_PLATFORMS "-iphoneos;-iphonesimulator") + +# you might also want to set the deployment target +# set(CMAKE_XCODE_ATTRIBUTE_IPHONEOS_DEPLOYMENT_TARGET "10.0") From 38b74670829309e4c62f909d73916cac7657d9e3 Mon Sep 17 00:00:00 2001 From: George Date: Sun, 8 Oct 2023 21:53:22 +0800 Subject: [PATCH 031/180] updated cmakelists --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 28093881a29d3..46a92e6905c65 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -168,8 +168,8 @@ if (APPLE AND LLAMA_ACCELERATE) message(STATUS "Accelerate framework found") add_compile_definitions(GGML_USE_ACCELERATE) - add_compile_definitions(ACCELERATE_NEW_LAPACK) - add_compile_definitions(ACCELERATE_LAPACK_ILP64) + # add_compile_definitions(ACCELERATE_NEW_LAPACK) + # add_compile_definitions(ACCELERATE_LAPACK_ILP64) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK}) else() message(WARNING "Accelerate framework not found") From 4e6db1f62b0cfad87a45ea69354e09fae84156da Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Thu, 12 Oct 2023 14:12:26 +0800 Subject: [PATCH 032/180] reverted memory fixes, see https://github.com/ggerganov/llama.cpp/pull/3527 --- ggml-alloc.c | 5 +++-- ggml.c | 55 ++++++++++++++++++++++++------------------------ llama.cpp | 59 ++++++++++++---------------------------------------- 3 files changed, 43 insertions(+), 76 deletions(-) diff --git a/ggml-alloc.c b/ggml-alloc.c index 0fb6c563c9d1a..49657b79760e3 100644 --- a/ggml-alloc.c +++ b/ggml-alloc.c @@ -317,8 +317,9 @@ struct ggml_allocr * ggml_allocr_new_from_buffer(struct ggml_backend_buffer * bu /*.parse_seq = */ {0}, /*.parse_seq_len = */ 0, #ifdef GGML_ALLOCATOR_DEBUG - (*alloc).allocated_tensors = {0}; + /*.allocated_tensors = */ {0}, #endif + }; ggml_allocr_reset(alloc); @@ -590,4 +591,4 @@ size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * size_t ggml_allocr_max_size(struct ggml_allocr * alloc) { return alloc->max_size; -} +} \ No newline at end of file diff --git a/ggml.c b/ggml.c index ab71c30a7c65b..d16233f12c999 100644 --- a/ggml.c +++ b/ggml.c @@ -4698,21 +4698,19 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { } const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN); - - ctx = (struct ggml_context *)malloc(sizeof(struct ggml_context)); - - struct ggml_scratch empty_scratch = { 0, 0, NULL }; - - (*ctx).mem_size = mem_size; - (*ctx).mem_buffer = params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size); - (*ctx).mem_buffer_owned = params.mem_buffer ? false : true; - (*ctx).no_alloc = params.no_alloc; - (*ctx).no_alloc_save = params.no_alloc; - (*ctx).n_objects = 0; - (*ctx).objects_begin = NULL; - (*ctx).objects_end = NULL; - (*ctx).scratch = empty_scratch; - (*ctx).scratch_save = empty_scratch; + + *ctx = (struct ggml_context) { + /*.mem_size =*/ mem_size, + /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size), + /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, + /*.no_alloc =*/ params.no_alloc, + /*.no_alloc_save =*/ params.no_alloc, + /*.n_objects =*/ 0, + /*.objects_begin =*/ NULL, + /*.objects_end =*/ NULL, + /*.scratch =*/ { 0, 0, NULL, }, + /*.scratch_save =*/ { 0, 0, NULL, }, + }; GGML_ASSERT(ctx->mem_buffer != NULL); @@ -18054,18 +18052,19 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) { struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, GGML_GRAPH_SIZE); struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs); - - (*cgraph).n_nodes = 0; - (*cgraph).n_leafs = 0; - (*cgraph).order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT; - (*cgraph).perf_runs = 0; - (*cgraph).perf_cycles = 0; - (*cgraph).perf_time_us = 0; - - memset((*cgraph).nodes, 0, sizeof((*cgraph).nodes)); - memset((*cgraph).grads, 0, sizeof((*cgraph).grads)); - memset((*cgraph).leafs, 0, sizeof((*cgraph).leafs)); - memset((*cgraph).visited_hash_table, 0, sizeof((*cgraph).visited_hash_table)); + + *cgraph = (struct ggml_cgraph) { + /*.n_nodes =*/ 0, + /*.n_leafs =*/ 0, + /*.nodes =*/ { NULL }, + /*.grads =*/ { NULL }, + /*.leafs =*/ { NULL }, + /*.hash_table =*/ { NULL }, + /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, + /*.perf_runs =*/ 0, + /*.perf_cycles =*/ 0, + /*.perf_time_us =*/ 0, + }; return cgraph; } @@ -22005,4 +22004,4 @@ int ggml_cpu_has_vsx(void) { #endif } -//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/llama.cpp b/llama.cpp index cff17028f5bf6..e99705b514dd4 100644 --- a/llama.cpp +++ b/llama.cpp @@ -9001,21 +9001,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat const size_t elt_size = ggml_element_size(kv_self.k); ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); - - // create a temporary cgraph without initialising ggml objects, code inspired from `ggml.c:ggml_new_graph` - struct ggml_cgraph * gf = (struct ggml_cgraph *) (malloc(sizeof(ggml_cgraph))); - - (*gf).n_nodes = 0; - (*gf).n_leafs = 0; - (*gf).order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT; - (*gf).perf_runs = 0; - (*gf).perf_cycles = 0; - (*gf).perf_time_us = 0; - - memset((*gf).nodes, 0, sizeof((*gf).nodes)); - memset((*gf).grads, 0, sizeof((*gf).grads)); - memset((*gf).leafs, 0, sizeof((*gf).leafs)); - memset((*gf).visited_hash_table, 0, sizeof((*gf).visited_hash_table)); + ggml_cgraph gf{}; ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer); std::vector kout3d_data(ggml_nbytes(kout3d), 0); @@ -9033,9 +9019,9 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat kv_head, n_embd, n_layer, elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); - ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k3d, kout3d)); - ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v3d, vout3d)); - ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1); + ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); + ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d)); + ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1); ggml_free(cpy_ctx); @@ -9043,10 +9029,6 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat // write them to file data_ctx->write(kout3d_data.data(), kout3d_data.size()); data_ctx->write(vout3d_data.data(), vout3d_data.size()); - - // free our allocated graph - free(gf); - gf = NULL; } for (uint32_t i = 0; i < kv_size; ++i) { @@ -9147,21 +9129,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { const size_t elt_size = ggml_element_size(kv_self.k); ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); - - // create a temporary cgraph without initialising ggml objects, code inspired from `ggml.c:ggml_new_graph` - struct ggml_cgraph * gf = (struct ggml_cgraph *) (malloc(sizeof(ggml_cgraph))); - - (*gf).n_nodes = 0; - (*gf).n_leafs = 0; - (*gf).order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT; - (*gf).perf_runs = 0; - (*gf).perf_cycles = 0; - (*gf).perf_time_us = 0; - - memset((*gf).nodes, 0, sizeof((*gf).nodes)); - memset((*gf).grads, 0, sizeof((*gf).grads)); - memset((*gf).leafs, 0, sizeof((*gf).leafs)); - memset((*gf).visited_hash_table, 0, sizeof((*gf).visited_hash_table)); + ggml_cgraph gf{}; ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer); kin3d->data = (void *) inp; @@ -9179,9 +9147,9 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { kv_head, n_embd, n_layer, elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); - ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin3d, k3d)); - ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin3d, v3d)); - ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1); + ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d)); + ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d)); + ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1); ggml_free(cpy_ctx); } @@ -9233,11 +9201,10 @@ static bool llama_load_session_file_internal(struct llama_context * ctx, const c llama_hparams session_hparams; file.read_raw(&session_hparams, sizeof(llama_hparams)); - // TODO: need to do floating point comparison imprecisely for norm_eps - //if (session_hparams != ctx->model.hparams) { - // LLAMA_LOG_INFO("%s : model hparams didn't match from session file!\n", __func__); - // return false; - //} + if (session_hparams != ctx->model.hparams) { + LLAMA_LOG_INFO("%s : model hparams didn't match from session file!\n", __func__); + return false; + } } // load the prompt @@ -9662,4 +9629,4 @@ static void llama_log_callback_default(ggml_log_level level, const char * text, (void) user_data; fputs(text, stderr); fflush(stderr); -} +} \ No newline at end of file From 5448c8c50a47b7fbeacbc84a7bf06620e60aab01 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Thu, 12 Oct 2023 15:29:39 +0800 Subject: [PATCH 033/180] fixed save_load_session on ios --- llama.cpp | 44 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/llama.cpp b/llama.cpp index e99705b514dd4..95a6fd8fdd47b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -9001,7 +9001,21 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat const size_t elt_size = ggml_element_size(kv_self.k); ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); - ggml_cgraph gf{}; + + // create a temporary cgraph without initialising ggml objects, code inspired from `ggml.c:ggml_new_graph` + struct ggml_cgraph * gf = (struct ggml_cgraph *) (malloc(sizeof(ggml_cgraph))); + + (*gf).n_nodes = 0; + (*gf).n_leafs = 0; + (*gf).order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT; + (*gf).perf_runs = 0; + (*gf).perf_cycles = 0; + (*gf).perf_time_us = 0; + + memset((*gf).nodes, 0, sizeof((*gf).nodes)); + memset((*gf).grads, 0, sizeof((*gf).grads)); + memset((*gf).leafs, 0, sizeof((*gf).leafs)); + memset((*gf).visited_hash_table, 0, sizeof((*gf).visited_hash_table)); ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer); std::vector kout3d_data(ggml_nbytes(kout3d), 0); @@ -9019,9 +9033,9 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat kv_head, n_embd, n_layer, elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); - ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); - ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d)); - ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1); + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k3d, kout3d)); + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v3d, vout3d)); + ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1); ggml_free(cpy_ctx); @@ -9129,7 +9143,21 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { const size_t elt_size = ggml_element_size(kv_self.k); ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); - ggml_cgraph gf{}; + + // create a temporary cgraph without initialising ggml objects, code inspired from `ggml.c:ggml_new_graph` + struct ggml_cgraph * gf = (struct ggml_cgraph *) (malloc(sizeof(ggml_cgraph))); + + (*gf).n_nodes = 0; + (*gf).n_leafs = 0; + (*gf).order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT; + (*gf).perf_runs = 0; + (*gf).perf_cycles = 0; + (*gf).perf_time_us = 0; + + memset((*gf).nodes, 0, sizeof((*gf).nodes)); + memset((*gf).grads, 0, sizeof((*gf).grads)); + memset((*gf).leafs, 0, sizeof((*gf).leafs)); + memset((*gf).visited_hash_table, 0, sizeof((*gf).visited_hash_table)); ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer); kin3d->data = (void *) inp; @@ -9147,9 +9175,9 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { kv_head, n_embd, n_layer, elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); - ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d)); - ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d)); - ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1); + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin3d, k3d)); + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin3d, v3d)); + ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1); ggml_free(cpy_ctx); } From 48c8cc44ed592b8c08fdbe5f03639998900a0743 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Wed, 1 Nov 2023 19:57:49 +0800 Subject: [PATCH 034/180] fixed grammar reset --- common/sampling.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/common/sampling.cpp b/common/sampling.cpp index 673d67a6d5380..1317024c2c11c 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -39,6 +39,7 @@ void llama_sampling_free(struct llama_sampling_context * ctx) { void llama_sampling_reset(llama_sampling_context * ctx) { if (ctx->grammar != NULL) { llama_grammar_free(ctx->grammar); + ctx->grammar = NULL; } if (!ctx->parsed_grammar.rules.empty()) { From 6105f956a7ad83f189a641a672d1e855b158b7bd Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Fri, 3 Nov 2023 13:38:23 +0800 Subject: [PATCH 035/180] reverted ios memfix --- llama.cpp | 44 ++++++++------------------------------------ 1 file changed, 8 insertions(+), 36 deletions(-) diff --git a/llama.cpp b/llama.cpp index dbe5d642a0f0c..cc2f03fd38d5f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8559,21 +8559,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat const size_t elt_size = ggml_element_size(kv_self.k); ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); - - // create a temporary cgraph without initialising ggml objects, code inspired from `ggml.c:ggml_new_graph` - struct ggml_cgraph * gf = (struct ggml_cgraph *) (malloc(sizeof(ggml_cgraph))); - - (*gf).n_nodes = 0; - (*gf).n_leafs = 0; - (*gf).order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT; - (*gf).perf_runs = 0; - (*gf).perf_cycles = 0; - (*gf).perf_time_us = 0; - - memset((*gf).nodes, 0, sizeof((*gf).nodes)); - memset((*gf).grads, 0, sizeof((*gf).grads)); - memset((*gf).leafs, 0, sizeof((*gf).leafs)); - memset((*gf).visited_hash_table, 0, sizeof((*gf).visited_hash_table)); + ggml_cgraph gf{}; ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer); std::vector kout3d_data(ggml_nbytes(kout3d), 0); @@ -8591,9 +8577,9 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat kv_head, n_embd, n_layer, elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); - ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k3d, kout3d)); - ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v3d, vout3d)); - ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1); + ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); + ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d)); + ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1); ggml_free(cpy_ctx); @@ -8701,21 +8687,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { const size_t elt_size = ggml_element_size(kv_self.k); ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); - - // create a temporary cgraph without initialising ggml objects, code inspired from `ggml.c:ggml_new_graph` - struct ggml_cgraph * gf = (struct ggml_cgraph *) (malloc(sizeof(ggml_cgraph))); - - (*gf).n_nodes = 0; - (*gf).n_leafs = 0; - (*gf).order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT; - (*gf).perf_runs = 0; - (*gf).perf_cycles = 0; - (*gf).perf_time_us = 0; - - memset((*gf).nodes, 0, sizeof((*gf).nodes)); - memset((*gf).grads, 0, sizeof((*gf).grads)); - memset((*gf).leafs, 0, sizeof((*gf).leafs)); - memset((*gf).visited_hash_table, 0, sizeof((*gf).visited_hash_table)); + ggml_cgraph gf{}; ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer); kin3d->data = (void *) inp; @@ -8733,9 +8705,9 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { kv_head, n_embd, n_layer, elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); - ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin3d, k3d)); - ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin3d, v3d)); - ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1); + ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d)); + ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d)); + ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1); ggml_free(cpy_ctx); } From 21e0296a857a31409fe7ff660675c08cdcf3ea3f Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Fri, 3 Nov 2023 15:25:05 +0800 Subject: [PATCH 036/180] Revert "reverted ios memfix" This reverts commit 6105f956a7ad83f189a641a672d1e855b158b7bd. --- llama.cpp | 44 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/llama.cpp b/llama.cpp index cc2f03fd38d5f..dbe5d642a0f0c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8559,7 +8559,21 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat const size_t elt_size = ggml_element_size(kv_self.k); ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); - ggml_cgraph gf{}; + + // create a temporary cgraph without initialising ggml objects, code inspired from `ggml.c:ggml_new_graph` + struct ggml_cgraph * gf = (struct ggml_cgraph *) (malloc(sizeof(ggml_cgraph))); + + (*gf).n_nodes = 0; + (*gf).n_leafs = 0; + (*gf).order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT; + (*gf).perf_runs = 0; + (*gf).perf_cycles = 0; + (*gf).perf_time_us = 0; + + memset((*gf).nodes, 0, sizeof((*gf).nodes)); + memset((*gf).grads, 0, sizeof((*gf).grads)); + memset((*gf).leafs, 0, sizeof((*gf).leafs)); + memset((*gf).visited_hash_table, 0, sizeof((*gf).visited_hash_table)); ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer); std::vector kout3d_data(ggml_nbytes(kout3d), 0); @@ -8577,9 +8591,9 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat kv_head, n_embd, n_layer, elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); - ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); - ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d)); - ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1); + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k3d, kout3d)); + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v3d, vout3d)); + ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1); ggml_free(cpy_ctx); @@ -8687,7 +8701,21 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { const size_t elt_size = ggml_element_size(kv_self.k); ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); - ggml_cgraph gf{}; + + // create a temporary cgraph without initialising ggml objects, code inspired from `ggml.c:ggml_new_graph` + struct ggml_cgraph * gf = (struct ggml_cgraph *) (malloc(sizeof(ggml_cgraph))); + + (*gf).n_nodes = 0; + (*gf).n_leafs = 0; + (*gf).order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT; + (*gf).perf_runs = 0; + (*gf).perf_cycles = 0; + (*gf).perf_time_us = 0; + + memset((*gf).nodes, 0, sizeof((*gf).nodes)); + memset((*gf).grads, 0, sizeof((*gf).grads)); + memset((*gf).leafs, 0, sizeof((*gf).leafs)); + memset((*gf).visited_hash_table, 0, sizeof((*gf).visited_hash_table)); ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer); kin3d->data = (void *) inp; @@ -8705,9 +8733,9 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { kv_head, n_embd, n_layer, elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); - ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d)); - ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d)); - ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1); + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin3d, k3d)); + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin3d, v3d)); + ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1); ggml_free(cpy_ctx); } From 15253f0806112711d13183aa823083ab000d9427 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 16 Jan 2024 20:13:06 +0900 Subject: [PATCH 037/180] wip --- llama.cpp | 66 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/llama.cpp b/llama.cpp index 7e4cd477332d2..08723cc24cb4d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8060,6 +8060,72 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da return result; } +llama_token llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp, float min_temp = 0, float max_temp = 2.0f) { + const int64_t t_start_sample_us = ggml_time_us(); + + llama_sample_softmax(ctx, candidates_p); + + float exponent_val = 1.0f; + + // Calculate entropy of the softmax probabilities + float entropy = 0.0f; + for (size_t i = 0; i < candidates_p->size; ++i) { + float prob = candidates_p->data[i].p; + if (prob > 0.0f) { // Ensure no log(0) + entropy -= prob * logf(prob); + } + } + + // Calculate maximum possible entropy + float max_entropy = -logf(1.0f / candidates_p->size); + + // Guard against division by zero + if (max_entropy == 0.0f) { + max_entropy = 1.0f; // This ensures that normalized_entropy will be 0 when entropy is 0 + } + + // Normalize the entropy + float normalized_entropy = entropy / max_entropy; + + // Map the normalized entropy to the desired temperature range using the power function + float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val); + + // create our probability table + std::vector probs; + probs.reserve(candidates->size); + for (size_t i = 0; i < candidates->size; ++i) { + probs.push_back(candidates->data[i].p); + } + + // Apply the dynamically calculated temperature scaling + for (size_t i = 0; i < candidates_p->size; ++i) { + candidates_p->data[i].logit /= dyn_temp; + } + + // Re-compute softmax probabilities after scaling logits with dynamic temperature + double max_l_double = candidates_p->data[0].logit; + double cum_sum_double = 0.0; + for (size_t i = 0; i < candidates_p->size; ++i) { + double p = exp(candidates_p->data[i].logit - max_l_double); + candidates_p->data[i].p = p; // Store the scaled probability + cum_sum_double += p; + } + for (size_t i = 0; i < candidates_p->size; ++i) { + candidates_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities + } + + // //todo: Ensure to hide print statements unless debugging! + // // Print the updated top 25 probabilities after temperature scaling + // printf("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n"); + // for (size_t i = 0; i < 25 && i < candidates_p->size; ++i) { + // printf("Token %zu: %f%%\n", i + 1, candidates_p->data[i].p * 100.0f); + // } + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { GGML_ASSERT(ctx); From 0787070f5a0ebf2f6ed353f17a11a96b38ee7b08 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 16 Jan 2024 20:36:46 +0900 Subject: [PATCH 038/180] implemented dynamic temperature sampling from Koboldcpp --- common/sampling.cpp | 21 ++++++- common/sampling.h | 1 + llama.cpp | 134 ++++++++++++++++++++++---------------------- llama.h | 8 +++ 4 files changed, 97 insertions(+), 67 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index dd1ffeb1b8083..b2a249ad4d262 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -129,6 +129,7 @@ static void sampler_queue( const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); const float temp = params.temp; + const float dynatemp_range = params.dynatemp_range; const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; const float top_p = params.top_p; const float min_p = params.min_p; @@ -143,7 +144,24 @@ static void sampler_queue( case 'y': llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; case 'p': llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; case 'm': llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; - case 't': llama_sample_temp (ctx_main, &cur_p, temp); break; + + case 't': + if (dynatemp_range>0) + { + float dynatemp_min = temp - dynatemp_range; + float dynatemp_max = temp + dynatemp_range; + //do not allow negative values + dynatemp_min = dynatemp_min<0?0:dynatemp_min; + dynatemp_max = dynatemp_max<0?0:dynatemp_max; + + llama_sample_entropy(ctx_main, &cur_p, temp, dynatemp_min, dynatemp_max); + } + else + { + llama_sample_temp(ctx_main, &cur_p, temp); + } + + break; default : break; } } @@ -160,6 +178,7 @@ static llama_token llama_sampling_sample_impl( const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); const float temp = params.temp; + const float dynatemp_range = params.dynatemp_range; const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; const float penalty_repeat = params.penalty_repeat; const float penalty_freq = params.penalty_freq; diff --git a/common/sampling.h b/common/sampling.h index f16ef97e34a10..d95e36433cc4a 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -18,6 +18,7 @@ typedef struct llama_sampling_params { float tfs_z = 1.00f; // 1.0 = disabled float typical_p = 1.00f; // 1.0 = disabled float temp = 0.80f; // 1.0 = disabled + float dynatemp_range = 0.00f; // 0.0 = disabled int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) float penalty_repeat = 1.10f; // 1.0 = disabled float penalty_freq = 0.00f; // 0.0 = disabled diff --git a/llama.cpp b/llama.cpp index 08723cc24cb4d..24ae8ac594b18 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7779,6 +7779,74 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c } } +void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp, float min_temp = 0, float max_temp = 2.0f) { + const int64_t t_start_sample_us = ggml_time_us(); + + llama_sample_softmax(ctx, candidates_p); + + float exponent_val = 1.0f; + + // Calculate entropy of the softmax probabilities + float entropy = 0.0f; + for (size_t i = 0; i < candidates_p->size; ++i) { + float prob = candidates_p->data[i].p; + if (prob > 0.0f) { // Ensure no log(0) + entropy -= prob * logf(prob); + } + } + + // Calculate maximum possible entropy + float max_entropy = -logf(1.0f / candidates_p->size); + + // Guard against division by zero + if (max_entropy == 0.0f) { + max_entropy = 1.0f; // This ensures that normalized_entropy will be 0 when entropy is 0 + } + + // Normalize the entropy + float normalized_entropy = entropy / max_entropy; + + // Map the normalized entropy to the desired temperature range using the power function + float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val); + + // //todo: Ensure to hide print statements unless debugging! + // printf("Your text maxtemp value is: %f\n", max_temp); + // // Print the variables + // printf("Entropy: %f\n", entropy); + // printf("Max Possible Entropy: %f\n", max_entropy); + // printf("Normalized Entropy: %f\n", normalized_entropy); + // printf("Exponent: %f\n", exponent_val); + // printf("Dynamic Temperature (dyn_temp): %f\n", dyn_temp); + + // Apply the dynamically calculated temperature scaling + for (size_t i = 0; i < candidates_p->size; ++i) { + candidates_p->data[i].logit /= dyn_temp; + } + + // Re-compute softmax probabilities after scaling logits with dynamic temperature + double max_l_double = candidates_p->data[0].logit; + double cum_sum_double = 0.0; + for (size_t i = 0; i < candidates_p->size; ++i) { + double p = exp(candidates_p->data[i].logit - max_l_double); + candidates_p->data[i].p = p; // Store the scaled probability + cum_sum_double += p; + } + for (size_t i = 0; i < candidates_p->size; ++i) { + candidates_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities + } + + // //todo: Ensure to hide print statements unless debugging! + // // Print the updated top 25 probabilities after temperature scaling + // printf("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n"); + // for (size_t i = 0; i < 25 && i < candidates_p->size; ++i) { + // printf("Token %zu: %f%%\n", i + 1, candidates_p->data[i].p * 100.0f); + // } + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { const int64_t t_start_sample_us = ggml_time_us(); @@ -8060,72 +8128,6 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da return result; } -llama_token llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp, float min_temp = 0, float max_temp = 2.0f) { - const int64_t t_start_sample_us = ggml_time_us(); - - llama_sample_softmax(ctx, candidates_p); - - float exponent_val = 1.0f; - - // Calculate entropy of the softmax probabilities - float entropy = 0.0f; - for (size_t i = 0; i < candidates_p->size; ++i) { - float prob = candidates_p->data[i].p; - if (prob > 0.0f) { // Ensure no log(0) - entropy -= prob * logf(prob); - } - } - - // Calculate maximum possible entropy - float max_entropy = -logf(1.0f / candidates_p->size); - - // Guard against division by zero - if (max_entropy == 0.0f) { - max_entropy = 1.0f; // This ensures that normalized_entropy will be 0 when entropy is 0 - } - - // Normalize the entropy - float normalized_entropy = entropy / max_entropy; - - // Map the normalized entropy to the desired temperature range using the power function - float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val); - - // create our probability table - std::vector probs; - probs.reserve(candidates->size); - for (size_t i = 0; i < candidates->size; ++i) { - probs.push_back(candidates->data[i].p); - } - - // Apply the dynamically calculated temperature scaling - for (size_t i = 0; i < candidates_p->size; ++i) { - candidates_p->data[i].logit /= dyn_temp; - } - - // Re-compute softmax probabilities after scaling logits with dynamic temperature - double max_l_double = candidates_p->data[0].logit; - double cum_sum_double = 0.0; - for (size_t i = 0; i < candidates_p->size; ++i) { - double p = exp(candidates_p->data[i].logit - max_l_double); - candidates_p->data[i].p = p; // Store the scaled probability - cum_sum_double += p; - } - for (size_t i = 0; i < candidates_p->size; ++i) { - candidates_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities - } - - // //todo: Ensure to hide print statements unless debugging! - // // Print the updated top 25 probabilities after temperature scaling - // printf("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n"); - // for (size_t i = 0; i < 25 && i < candidates_p->size; ++i) { - // printf("Token %zu: %f%%\n", i + 1, candidates_p->data[i].p * 100.0f); - // } - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { GGML_ASSERT(ctx); diff --git a/llama.h b/llama.h index a570b0d6968fb..37b48fa6cd266 100644 --- a/llama.h +++ b/llama.h @@ -770,6 +770,14 @@ extern "C" { float p, size_t min_keep); + /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772. + LLAMA_API void llama_sample_entropy( + struct llama_context * ctx, + llama_token_data_array * candidates_p, + float temp, + float min_temp, + float max_temp); + LLAMA_API void llama_sample_temp( struct llama_context * ctx, llama_token_data_array * candidates, From a1c004ef2e056cdeffcd47aaac196883bb123a3a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Jan 2024 17:42:55 +0200 Subject: [PATCH 039/180] ggml : add ggml_flash_attn_ext API --- ggml-metal.m | 50 +++++++ ggml-metal.metal | 29 ++++ ggml.c | 298 ++++++++++++++++++++++++++++++++++++- ggml.h | 9 ++ llama.cpp | 80 +++++----- tests/test-backend-ops.cpp | 28 ++++ 6 files changed, 456 insertions(+), 38 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 912ddc83f7d9c..6d88d5c36a8ad 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -147,6 +147,7 @@ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -511,6 +512,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, flash_attn_ext_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -665,6 +667,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_PAD: case GGML_OP_ARGSORT: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: @@ -2161,6 +2164,53 @@ static bool ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_FLASH_ATTN_EXT: + { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + + struct ggml_tensor * src2 = gf->nodes[i]->src[2]; + struct ggml_tensor * src3 = gf->nodes[i]->src[3]; + + size_t offs_src2 = 0; + size_t offs_src3 = 0; + + id id_src2 = src2 ? ggml_metal_get_buffer(ctx, src2, &offs_src2) : nil; + id id_src3 = src3 ? ggml_metal_get_buffer(ctx, src3, &offs_src3) : nil; + + float scale; + memcpy(&scale, dst->op_params, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16].pipeline; + + // TODO: extend if necessary + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&scale length:sizeof( float) atIndex:21]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: diff --git a/ggml-metal.metal b/ggml-metal.metal index 029578dc54dbd..b79a1ba5634a7 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1959,6 +1959,35 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; } +kernel void kernel_flash_attn_ext_f16( + device const half * q, + device const half * k, + device const half * v, + device const half * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant float & scale, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + // TODO: implement +} + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, diff --git a/ggml.c b/ggml.c index cbf2d4bddddb8..e01d938ceb681 100644 --- a/ggml.c +++ b/ggml.c @@ -1650,6 +1650,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "LEAKY_RELU", "FLASH_ATTN", + "FLASH_ATTN_EXT", "FLASH_FF", "FLASH_ATTN_BACK", "WIN_PART", @@ -1674,7 +1675,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72"); +static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1736,6 +1737,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "leaky_relu(x)", "flash_attn(x)", + "flash_attn_ext(x)", "flash_ff(x)", "flash_attn_back(x)", "win_part(x)", @@ -1760,7 +1762,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72"); +static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5678,6 +5680,46 @@ struct ggml_tensor * ggml_flash_attn( return result; } +// ggml_flash_attn_ext + +struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale) { + GGML_ASSERT(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + if (mask) { + GGML_ASSERT(ggml_is_contiguous(mask)); + GGML_ASSERT(mask->ne[2] == 1); + GGML_ASSERT(mask->ne[3] == 1); + //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); + } + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + is_node = true; + } + + //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne); + + float params[] = { scale }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_FLASH_ATTN_EXT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = mask; + + return result; +} + // ggml_flash_ff struct ggml_tensor * ggml_flash_ff( @@ -13212,6 +13254,251 @@ static void ggml_compute_forward_flash_attn( } } +// ggml_compute_forward_flash_attn_ext + +static void ggml_compute_forward_flash_attn_ext_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + const int64_t P = nek1 - N; + const int64_t M = P + N; + + const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float scale = 1.0f; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { + for (int64_t ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2 % nek2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16(neq0, + S + i1, + (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } else { + for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2 % nek2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16_unroll(neq0, nbk1, + S + i1, + ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (mask) { + const float * mp = (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]); + ggml_vec_acc_f32(M, S, mp); + } + + // softmax + // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. + // dont forget to set their S values to zero + { + float max = -INFINITY; + ggml_vec_max_f32(M, &max, S); + + ggml_float sum = 0.0; + { +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(S, 1, &max, S, 1, Mup); + vvexpf(S, S, &Mup); + ggml_vec_sum_f32(Mup, &sum, S); +#else + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + float * SS = S + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (SS[j] == -INFINITY) { + SS[j] = 0.0f; + } else { + ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]); + sump[j] += (ggml_float)val; + SS[j] = val; + } + } + } + + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, S, sum); + +#ifndef NDEBUG + for (int i = 0; i < M; ++i) { + assert(!isnan(S[i])); + assert(!isinf(S[i])); + } +#endif + } + + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); + + for (int64_t i = 0; i < M; i++) { + S16[i] = GGML_FP32_TO_FP16(S[i]); + } + + // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). + if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { + for (int64_t ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f16(nev0, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + S16); + } + } else { + for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f16_unroll(nev0, nbv1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + S16); + } + } + } +} + +static void ggml_compute_forward_flash_attn_ext( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + switch (q->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_flash_ff static void ggml_compute_forward_flash_ff_f16( @@ -14717,6 +15004,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm const bool masked = t != 0; ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor); } break; + case GGML_OP_FLASH_ATTN_EXT: + { + ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } break; case GGML_OP_FLASH_FF: { ggml_compute_forward_flash_ff(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor); @@ -15713,6 +16004,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor GGML_ASSERT(false); // TODO: not implemented } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { struct ggml_tensor * flash_grad = NULL; if (src0->grad || src1->grad || tensor->src[2]->grad) { @@ -16438,6 +16730,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { n_tasks = n_threads; } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { n_tasks = n_threads; } break; @@ -16769,6 +17062,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); diff --git a/ggml.h b/ggml.h index de8162b8135f3..d76fe9d5c48c9 100644 --- a/ggml.h +++ b/ggml.h @@ -452,6 +452,7 @@ extern "C" { GGML_OP_LEAKY_RELU, GGML_OP_FLASH_ATTN, + GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_FF, GGML_OP_FLASH_ATTN_BACK, GGML_OP_WIN_PART, @@ -1619,6 +1620,14 @@ extern "C" { struct ggml_tensor * v, bool masked); + GGML_API struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale); + GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index d28382f7d47b7..cec23c23f1dce 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4205,38 +4205,6 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(k, "k", il); - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); - - if (model.arch == LLM_ARCH_PHI2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } - - if (max_alibi_bias > 0.0f) { - // temporary branch until we figure out how to handle ggml_alibi through ggml_add - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); - - if (max_alibi_bias > 0.0f) { - // TODO: n_head or n_head_kv - // TODO: K-shift is likely not working - // TODO: change to ggml_add - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); - } - - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); - - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); - cb(kq, "kq_soft_max_ext", il); - } - // split cached v into n_head heads struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], @@ -4246,8 +4214,49 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); + // TODO: determine if we can use flash attention + const bool supports_flash_attn = true; + + struct ggml_tensor * kqv; + + if (supports_flash_attn) { + kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); + } else { + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + + if (model.arch == LLM_ARCH_PHI2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } + + if (max_alibi_bias > 0.0f) { + // temporary branch until we figure out how to handle ggml_alibi through ggml_add + kq = ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); + + if (max_alibi_bias > 0.0f) { + // TODO: n_head or n_head_kv + // TODO: K-shift is likely not working + // TODO: change to ggml_add + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); + } + + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); + + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); + cb(kq, "kq_soft_max_ext", il); + } + + kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); + } struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); @@ -9490,8 +9499,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, - cparams.n_ctx, cparams.offload_kqv)) { + if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, cparams.n_ctx, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 55ce14e0d902c..5693c2197c7c5 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1384,6 +1384,32 @@ struct test_leaky_relu : public test_case { } }; +// GGML_OP_FLASH_ATTN_EXT +struct test_flash_attn_ext : public test_case { + const ggml_type typeq; + const int64_t hs; // head size + const int64_t nh; // num heads + const int64_t kv; // kv size + const int64_t nt; // tokens + + std::string vars() override { + return VARS_TO_STR5(typeq, hs, nh, kv, nt); + } + + test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, + int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nt = 8) + : typeq(typeq), hs(hs), nh(nh), kv(kv), nt(nt) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nt, nh, 1); + ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1); + ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); + return out; + } +}; + // Mixtral MOE struct test_moe : public test_case { const int n_experts; @@ -1650,6 +1676,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 96, 8)); + #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024)); From fa7ebcca993ec0d47f6ed6a47a8d5ac4f7407262 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Jan 2024 20:06:26 +0200 Subject: [PATCH 040/180] ggml : fix GQA support in ggml_flash_attn_ext --- ggml-metal.metal | 8 ++++---- ggml.c | 23 +++++++++++++++-------- llama.cpp | 4 ++++ 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index b79a1ba5634a7..28847794cb5d8 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1960,10 +1960,10 @@ kernel void kernel_leaky_relu_f32( } kernel void kernel_flash_attn_ext_f16( - device const half * q, - device const half * k, - device const half * v, - device const half * mask, + device const half * q, + device const half * k, + device const half * v, + device const float * mask, device float * dst, constant int64_t & ne00, constant int64_t & ne01, diff --git a/ggml.c b/ggml.c index e01d938ceb681..9cf4784ce4759 100644 --- a/ggml.c +++ b/ggml.c @@ -13307,6 +13307,13 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); + // broadcast factors + const int64_t rk2 = neq2/nek2; + const int64_t rk3 = neq3/nek3; + + const int64_t rv2 = neq2/nev2; + const int64_t rv3 = neq3/nev3; + if (params->type == GGML_TASK_INIT) { return; } @@ -13347,8 +13354,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { for (int64_t ic = 0; ic < nek1; ++ic) { // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; const int ik1 = ic; // S indices @@ -13362,8 +13369,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( } else { for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; const int ik1 = ic; // S indices @@ -13452,8 +13459,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int i3 = iq3; // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; ggml_vec_dot_f16(nev0, (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), @@ -13468,8 +13475,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int i3 = iq3; // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; ggml_vec_dot_f16_unroll(nev0, nbv1, (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), diff --git a/llama.cpp b/llama.cpp index cec23c23f1dce..d4bebe5203e9a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4220,6 +4220,10 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kqv; if (supports_flash_attn) { + //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); + //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); } else { struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); From a9681febd65cbd3f372badc5f4a4d8bc1336d2d9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 Jan 2024 12:26:49 +0200 Subject: [PATCH 041/180] ggml : online attention (CPU) --- ggml-metal.m | 8 +- ggml-metal.metal | 3 +- ggml.c | 249 ++++++++++++++++++------------------- ggml.h | 5 + llama.cpp | 124 ++++++++++-------- tests/test-backend-ops.cpp | 14 +-- 6 files changed, 218 insertions(+), 185 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 6d88d5c36a8ad..4d85dd3ddb319 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2207,9 +2207,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20]; [encoder setBytes:&scale length:sizeof( float) atIndex:21]; + const int nwarps = 4; + + // each warp needs n_embd_head elements + GGML_ASSERT(nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:nwarps*ne00*sizeof(float) atIndex:0]; + const int nth = MIN(1024, ne0); - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 28847794cb5d8..a1e1755a3a605 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1981,7 +1981,8 @@ kernel void kernel_flash_attn_ext_f16( constant uint64_t & nb1, constant uint64_t & nb2, constant uint64_t & nb3, - constant float & scale, + constant float & scale, + threadgroup float * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { diff --git a/ggml.c b/ggml.c index 9cf4784ce4759..e64a328fadb1f 100644 --- a/ggml.c +++ b/ggml.c @@ -817,7 +817,7 @@ do { \ #if defined(__F16C__) // the _mm256_cvt intrinsics require F16C -#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) #else static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { @@ -1323,6 +1323,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float #endif } +inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + } +#endif +} + // xs and vs are byte strides of x and v inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { @@ -1407,6 +1438,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #endif } +inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_MUL(ay[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#endif +} + inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } @@ -5704,8 +5764,9 @@ struct ggml_tensor * ggml_flash_attn_ext( is_node = true; } - //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne); + // permute(0, 2, 1, 3) + int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne); float params[] = { scale }; ggml_set_op_params(result, params, sizeof(params)); @@ -13281,12 +13342,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t D = neq0; const int64_t N = neq1; const int64_t P = nek1 - N; - const int64_t M = P + N; - - const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); GGML_ASSERT(ne0 == D); - GGML_ASSERT(ne1 == N); + GGML_ASSERT(ne2 == N); GGML_ASSERT(P >= 0); GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); @@ -13295,11 +13353,11 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(neq0 == D); GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev1 == D); + GGML_ASSERT(nev0 == D); GGML_ASSERT(neq1 == N); GGML_ASSERT(nek1 == N + P); - GGML_ASSERT(nev1 == D); + GGML_ASSERT(nev0 == D); // dst cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); @@ -13339,151 +13397,87 @@ static void ggml_compute_forward_flash_attn_ext_f16( //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices const int iq3 = ir/(neq2*neq1); const int iq2 = (ir - iq3*neq2*neq1)/neq1; const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); + float S = 0.0f; + float M = -INFINITY; - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } + float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); + ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D); - if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { - for (int64_t ic = 0; ic < nek1; ++ic) { - // k indices - const int ik3 = iq3 / rk3; - const int ik2 = iq2 / rk2; - const int ik1 = ic; + memset(V16, 0, D*sizeof(ggml_fp16_t)); - // S indices - const int i1 = ik1; + const float * mp = mask ? (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]) : NULL; - ggml_vec_dot_f16(neq0, - S + i1, - (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - } else { - for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { - // k indices - const int ik3 = iq3 / rk3; - const int ik2 = iq2 / rk2; - const int ik1 = ic; + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; - // S indices - const int i1 = ik1; + // v indices + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; - ggml_vec_dot_f16_unroll(neq0, nbk1, - S + i1, - ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + // online softmax / attention + // loop over n_kv and n_head_kv + // ref: https://arxiv.org/pdf/2112.05682.pdf + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? mp[ic] : 0.0f; + if (mv == -INFINITY) { + continue; } - } - // scale - ggml_vec_scale_f32(nek1, S, scale); + float s; - if (mask) { - const float * mp = (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]); - ggml_vec_acc_f32(M, S, mp); - } + ggml_vec_dot_f16(D, + &s, + (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - // softmax - // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. - // dont forget to set their S values to zero - { - float max = -INFINITY; - ggml_vec_max_f32(M, &max, S); + s = s*scale + mv; - ggml_float sum = 0.0; - { -#ifdef GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(S, 1, &max, S, 1, Mup); - vvexpf(S, S, &Mup); - ggml_vec_sum_f32(Mup, &sum, S); -#else - uint16_t scvt[GGML_SOFT_MAX_UNROLL]; - ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + const float Mold = M; - for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { - float * SS = S + i; + float ms = 1.0f; + float vs = 1.0f; - for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { - if (SS[j] == -INFINITY) { - SS[j] = 0.0f; - } else { - ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); - memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]); - sump[j] += (ggml_float)val; - SS[j] = val; - } - } - } + if (s > M) { + M = s; + ms = expf(Mold - M); - for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { - sum += sump[i]; - } -#endif + // V = V*expf(Mold - M) + ggml_vec_scale_f16(D, V16, ms); + } else { + vs = expf(s - M); } - assert(sum > 0.0); + const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); - sum = 1.0/sum; - ggml_vec_scale_f32(M, S, sum); + // V += v*expf(s - M) + ggml_vec_mad_f16(D, V16, v16, vs); -#ifndef NDEBUG - for (int i = 0; i < M; ++i) { - assert(!isnan(S[i])); - assert(!isinf(S[i])); - } -#endif + S = S*ms + vs; } - ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); - - for (int64_t i = 0; i < M; i++) { - S16[i] = GGML_FP32_TO_FP16(S[i]); + // V /= S + for (int64_t d = 0; d < D; ++d) { + V32[d] = GGML_FP16_TO_FP32(V16[d])/S; } - // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). - if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { - for (int64_t ic = 0; ic < nev1; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - // v indices - const int iv2 = iq2 / rv2; - const int iv3 = iq3 / rv3; - - ggml_vec_dot_f16(nev0, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - S16); - } - } else { - for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; - // v indices - const int iv2 = iq2 / rv2; - const int iv3 = iq3 / rv3; + // original + //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); - ggml_vec_dot_f16_unroll(nev0, nbv1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - S16); - } - } + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1); } } @@ -17069,7 +17063,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; case GGML_OP_FLASH_ATTN: - case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); @@ -17081,6 +17074,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 } } break; + case GGML_OP_FLASH_ATTN_EXT: + { + const int64_t ne00 = node->src[0]->ne[0]; // D + + cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size + } break; case GGML_OP_FLASH_FF: { if (node->src[1]->type == GGML_TYPE_F32) { diff --git a/ggml.h b/ggml.h index d76fe9d5c48c9..7bca02f2a2c48 100644 --- a/ggml.h +++ b/ggml.h @@ -1620,6 +1620,11 @@ extern "C" { struct ggml_tensor * v, bool masked); + // q: [n_embd, n_batch, n_head, 1] + // k: [n_embd, n_kv, n_head_kv, 1] + // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch, 1, 1] + // res: [n_embd, n_head, n_batch, 1] !! permuted !! GGML_API struct ggml_tensor * ggml_flash_attn_ext( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index f0a63afef0087..4e6c9f9cc75ea 100644 --- a/llama.cpp +++ b/llama.cpp @@ -95,6 +95,8 @@ #define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_EXPERTS 8 +#define LLAMA_FLASH_ATTN + // // logging // @@ -4167,23 +4169,34 @@ static void llm_build_kv_store( const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - // compute the transposed [n_tokens, n_embd] V matrix - struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); - //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed - cb(v_cur_t, "v_cur_t", il); - struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); cb(k_cache_view, "k_cache_view", il); + // important: storing RoPE-ed version of K in the KV cache! + ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); + +#if defined(LLAMA_FLASH_ATTN) + // NOTE: the V cache is not transposed when using FLASH attention !! + struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, + (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head); + cb(v_cache_view, "v_cache_view", il); + + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); + + GGML_UNUSED(n_ctx); +#else + // compute the transposed [n_tokens, n_embd] V matrix + //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); + struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed + cb(v_cur_t, "v_cur_t", il); + struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, ( n_ctx)*ggml_element_size(kv.v_l[il]), (kv_head)*ggml_element_size(kv.v_l[il])); - cb(v_cache_view, "v_cache_view", il); - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); +#endif } static struct ggml_tensor * llm_build_norm( @@ -4343,68 +4356,77 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(k, "k", il); - // split cached v into n_head heads + struct ggml_tensor * cur; + +#if defined(LLAMA_FLASH_ATTN) + // split cached v into n_head heads (not transposed) struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], - n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(kv.v_l[il])*n_ctx, - ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.v_l[il]->type, n_embd_head_k), 0); cb(v, "v", il); - // TODO: determine if we can use flash attention - const bool supports_flash_attn = true; + cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); + //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); + //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); + //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]); - struct ggml_tensor * kqv; + cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); +#else + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); - if (supports_flash_attn) { - //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); - //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); - //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); - //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); - kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); - } else { - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); + if (model.arch == LLM_ARCH_PHI2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } - if (model.arch == LLM_ARCH_PHI2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } + if (max_alibi_bias > 0.0f) { + // temporary branch until we figure out how to handle ggml_alibi through ggml_add + kq = ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); if (max_alibi_bias > 0.0f) { - // temporary branch until we figure out how to handle ggml_alibi through ggml_add - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); + // TODO: n_head or n_head_kv + // TODO: K-shift is likely not working + // TODO: change to ggml_add + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); + } - if (max_alibi_bias > 0.0f) { - // TODO: n_head or n_head_kv - // TODO: K-shift is likely not working - // TODO: change to ggml_add - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); - } + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); + cb(kq, "kq_soft_max_ext", il); + } - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); - cb(kq, "kq_soft_max_ext", il); - } + // split cached v into n_head heads (transposed) + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv.v_l[il])*n_ctx, + ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); - kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); - } + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); - struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); cb(cur, "kqv_merged_cont", il); +#endif cur = ggml_mul_mat(ctx, wo, cur); if (wo_b) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 5693c2197c7c5..a56c0d6c59a64 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1390,21 +1390,21 @@ struct test_flash_attn_ext : public test_case { const int64_t hs; // head size const int64_t nh; // num heads const int64_t kv; // kv size - const int64_t nt; // tokens + const int64_t nb; // batch size std::string vars() override { - return VARS_TO_STR5(typeq, hs, nh, kv, nt); + return VARS_TO_STR5(typeq, hs, nh, kv, nb); } test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, - int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nt = 8) - : typeq(typeq), hs(hs), nh(nh), kv(kv), nt(nt) {} + int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nt, nh, 1); + ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); - ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); - ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); return out; } From 1173f49c3bbe30810af4aeb77219eba7e05f658d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 Jan 2024 17:32:28 +0200 Subject: [PATCH 042/180] metal : initial implementation --- ggml-metal.m | 75 +++++++++++++------- ggml-metal.metal | 138 ++++++++++++++++++++++++++++++++++--- ggml.c | 2 +- tests/test-backend-ops.cpp | 4 ++ 4 files changed, 183 insertions(+), 36 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 4d85dd3ddb319..556c53482a75e 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -278,6 +278,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ NSURL * libURL = [NSURL fileURLWithPath:libPath]; GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]); ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; + if (error) { + GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } } else { GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); @@ -316,13 +320,12 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ //[options setFastMathEnabled:false]; ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; + if (error) { + GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } } } - - if (error) { - GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } } // print MTL GPU family: @@ -396,6 +399,9 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \ kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \ + GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ + (int) kernel->pipeline.threadExecutionWidth); \ if (error) { \ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ return NULL; \ @@ -2171,12 +2177,28 @@ static bool ggml_metal_graph_compute( struct ggml_tensor * src2 = gf->nodes[i]->src[2]; struct ggml_tensor * src3 = gf->nodes[i]->src[3]; + GGML_ASSERT(ggml_are_same_shape(src1, src2)); + size_t offs_src2 = 0; size_t offs_src3 = 0; - id id_src2 = src2 ? ggml_metal_get_buffer(ctx, src2, &offs_src2) : nil; + GGML_ASSERT(src2); + id id_src2 = ggml_metal_get_buffer(ctx, src2, &offs_src2); + id id_src3 = src3 ? ggml_metal_get_buffer(ctx, src3, &offs_src3) : nil; + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); + const int64_t ne31 = src3 ? src3->ne[1] : 0; + const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); + const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); + + const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); + const uint64_t nb31 = src3 ? src3->nb[1] : 0; + const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); + const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); + + const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + float scale; memcpy(&scale, dst->op_params, sizeof(float)); @@ -2197,25 +2219,28 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:16]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&scale length:sizeof( float) atIndex:21]; - - const int nwarps = 4; - - // each warp needs n_embd_head elements - GGML_ASSERT(nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:nwarps*ne00*sizeof(float) atIndex:0]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; + [encoder setBytes:&scale length:sizeof( float) atIndex:27]; + + const int nwarps = 1; + + GGML_ASSERT(2*32*nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index a1e1755a3a605..5986bcb427f4b 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1960,10 +1960,10 @@ kernel void kernel_leaky_relu_f32( } kernel void kernel_flash_attn_ext_f16( - device const half * q, - device const half * k, - device const half * v, - device const float * mask, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1973,20 +1973,138 @@ kernel void kernel_flash_attn_ext_f16( constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, constant float & scale, threadgroup float * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - // TODO: implement + uint3 ntg[[threads_per_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const int64_t iq3 = tgpig[2]; + const int64_t iq2 = tgpig[1]; + const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; + + if (iq1 >= ne01) { + return; + } + + const int64_t D = ne00; + + // TODO: can we move this to the stack? + threadgroup half * V16 = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); + + // initialize with zeros + for (int64_t d = 0; d < D; ++d) { + V16[d] = 0.0h; + } + + threadgroup half * pq = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); + + half S = 0.0h; + half M = -INFINITY; + + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + + device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; + + // assume K and V are same shape + const int64_t ne22 = ne12; + const int64_t ne23 = ne13; + + const uint64_t nb21 = nb11; + const uint64_t nb22 = nb12; + const uint64_t nb23 = nb13; + + // broadcast + const int64_t rk2 = ne02/ne12; + const int64_t rk3 = ne03/ne13; + + const int64_t rv2 = ne02/ne22; + const int64_t rv3 = ne03/ne23; + + // k indices + const int64_t ik2 = iq2 / rk2; + const int64_t ik3 = iq3 / rk3; + + // v indices + const int64_t iv2 = iq2 / rv2; + const int64_t iv3 = iq3 / rv3; + + // load Q to shared memory + for (int64_t d = 0; d < D; ++d) { + pq[d] = ((device const half *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + } + + for (int64_t ic = 0; ic < ne11; ++ic) { + const half mv = mp ? mp[ic] : 0.0h; + if (mv == -INFINITY) { + continue; + } + + half s = 0.0f; + + //device const half * pq = (device const half *) ((device char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + device const half * pk = (device const half *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); + + for (int64_t d = 0; d < D; ++d) { + s += pk[d] * pq[d]; + } + + s = s*scale + mv; + + const half Mold = M; + + half ms = 1.0f; + half vs = 1.0f; + + if (s > M) { + M = s; + ms = exp(Mold - M); + + // V = V*exp(Mold - M) + for (int64_t d = 0; d < D; ++d) { + V16[d] *= ms; + } + } else { + vs = exp(s - M); + } + + device const half * pv = (device const half *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + + // V += v*exp(s - M) + for (int64_t d = 0; d < D; ++d) { + V16[d] += pv[d] * vs; + } + + S = S*ms + vs; + } + + for (int64_t d = 0; d < D; ++d) { + V16[d] /= S; + } + + // dst indices + const int64_t i1 = iq1; + const int64_t i2 = iq2; + const int64_t i3 = iq3; + + for (int64_t d = 0; d < D; ++d) { + dst[(i3*ne2*ne1 + i2 + i1*ne1)*D + d] = V16[d]; + } } kernel void kernel_cpy_f16_f16( diff --git a/ggml.c b/ggml.c index e64a328fadb1f..10df03c9c619b 100644 --- a/ggml.c +++ b/ggml.c @@ -13419,8 +13419,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int ik2 = iq2 / rk2; // v indices - const int iv2 = iq2 / rv2; const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; // online softmax / attention // loop over n_kv and n_head_kv diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a56c0d6c59a64..51a33c662da56 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1396,6 +1396,10 @@ struct test_flash_attn_ext : public test_case { return VARS_TO_STR5(typeq, hs, nh, kv, nb); } + double max_nmse_err() override { + return 5e-4; + } + test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {} From df2a4dca5e13d5ddc0017dfbfec09970992ae976 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Sun, 21 Jan 2024 18:05:33 +0900 Subject: [PATCH 043/180] pointed ggml dependency to layla-build in swift package --- Package.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Package.swift b/Package.swift index 37524edee8cd4..ede2e55d42ef4 100644 --- a/Package.swift +++ b/Package.swift @@ -14,7 +14,7 @@ let package = Package( .library(name: "llama", targets: ["llama"]), ], dependencies: [ - .package(url: "https://github.com/ggerganov/ggml.git", .branch("release")) + .package(url: "https://github.com/l3utterfly/ggml.git", .branch("layla-build")) ], targets: [ .target( From 528da7515ef874ab1188ab8f691c36d3e9e0cb20 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 11:13:24 +0200 Subject: [PATCH 044/180] metal : f16 precision --- ggml-metal.m | 6 ++++-- ggml-metal.metal | 40 ++++++++++++++++++++++------------------ 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 556c53482a75e..e67a7c4ef892b 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2237,8 +2237,10 @@ static bool ggml_metal_graph_compute( const int nwarps = 1; - GGML_ASSERT(2*32*nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*sizeof(float) atIndex:0]; + const size_t shalf = sizeof(float)/2; + + GGML_ASSERT(2*32*nwarps*ne00*shalf <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*shalf atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 5986bcb427f4b..e4e89b5b3f7bf 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1988,7 +1988,7 @@ kernel void kernel_flash_attn_ext_f16( constant int64_t & ne2, constant int64_t & ne3, constant float & scale, - threadgroup float * shared [[threadgroup(0)]], + threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]], @@ -2003,16 +2003,17 @@ kernel void kernel_flash_attn_ext_f16( } const int64_t D = ne00; + const int64_t D4 = D/4; // TODO: can we move this to the stack? - threadgroup half * V16 = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); + threadgroup half4 * V16 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); // initialize with zeros - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < D4; ++d) { V16[d] = 0.0h; } - threadgroup half * pq = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); half S = 0.0h; half M = -INFINITY; @@ -2045,8 +2046,8 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv3 = iq3 / rv3; // load Q to shared memory - for (int64_t d = 0; d < D; ++d) { - pq[d] = ((device const half *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + for (int64_t d = 0; d < D4; ++d) { + pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; } for (int64_t ic = 0; ic < ne11; ++ic) { @@ -2055,15 +2056,16 @@ kernel void kernel_flash_attn_ext_f16( continue; } - half s = 0.0f; + half4 s4 = 0.0f; - //device const half * pq = (device const half *) ((device char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); - device const half * pk = (device const half *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); + device const half4 * pk4 = (device const half4 *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); - for (int64_t d = 0; d < D; ++d) { - s += pk[d] * pq[d]; + for (int64_t d = 0; d < D4; ++d) { + s4 += pk4[d] * pq4[d]; } + half s = s4.x + s4.y + s4.z + s4.w; + s = s*scale + mv; const half Mold = M; @@ -2076,24 +2078,24 @@ kernel void kernel_flash_attn_ext_f16( ms = exp(Mold - M); // V = V*exp(Mold - M) - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < D4; ++d) { V16[d] *= ms; } } else { vs = exp(s - M); } - device const half * pv = (device const half *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); // V += v*exp(s - M) - for (int64_t d = 0; d < D; ++d) { - V16[d] += pv[d] * vs; + for (int64_t d = 0; d < D4; ++d) { + V16[d] += pv4[d] * vs; } S = S*ms + vs; } - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < D4; ++d) { V16[d] /= S; } @@ -2102,8 +2104,10 @@ kernel void kernel_flash_attn_ext_f16( const int64_t i2 = iq2; const int64_t i3 = iq3; - for (int64_t d = 0; d < D; ++d) { - dst[(i3*ne2*ne1 + i2 + i1*ne1)*D + d] = V16[d]; + device float4 * dst4 = (device float4 *) dst; + + for (int64_t d = 0; d < D4; ++d) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; } } From 52ae085750afd37affc4ed18fe092d92c9ccdc5f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 11:38:17 +0200 Subject: [PATCH 045/180] metal : reduce branches --- ggml-metal.metal | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index e4e89b5b3f7bf..f3a7efafa6613 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2056,40 +2056,26 @@ kernel void kernel_flash_attn_ext_f16( continue; } - half4 s4 = 0.0f; + device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); + device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); - device const half4 * pk4 = (device const half4 *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); + half4 s4 = 0.0h; for (int64_t d = 0; d < D4; ++d) { s4 += pk4[d] * pq4[d]; } - half s = s4.x + s4.y + s4.z + s4.w; - - s = s*scale + mv; + half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; const half Mold = M; - half ms = 1.0f; - half vs = 1.0f; - - if (s > M) { - M = s; - ms = exp(Mold - M); - - // V = V*exp(Mold - M) - for (int64_t d = 0; d < D4; ++d) { - V16[d] *= ms; - } - } else { - vs = exp(s - M); - } + M = max(M, s); - device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + const half ms = exp(Mold - M); + const half vs = exp(s - M); - // V += v*exp(s - M) for (int64_t d = 0; d < D4; ++d) { - V16[d] += pv4[d] * vs; + V16[d] = V16[d]*ms + pv4[d]*vs; } S = S*ms + vs; From b97325800a7727244e737715fa7b5e2bc41afb21 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 12:01:55 +0200 Subject: [PATCH 046/180] metal : specialize for head size --- ggml-metal.m | 259 +++++++++++++++++++++++++---------------------- ggml-metal.metal | 42 +++++++- 2 files changed, 179 insertions(+), 122 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index e67a7c4ef892b..046643146b3f3 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -147,7 +147,9 @@ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -412,125 +414,127 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ // simd_sum and simd_max requires MTLGPUFamilyApple7 - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, flash_attn_ext_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } return ctx; @@ -2172,6 +2176,7 @@ static bool ggml_metal_graph_compute( } break; case GGML_OP_FLASH_ATTN_EXT: { + GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(src0->type == GGML_TYPE_F16); struct ggml_tensor * src2 = gf->nodes[i]->src[2]; @@ -2202,7 +2207,19 @@ static bool ggml_metal_graph_compute( float scale; memcpy(&scale, dst->op_params, sizeof(float)); - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16].pipeline; + id pipeline = nil; + + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } // TODO: extend if necessary [encoder setComputePipelineState:pipeline]; diff --git a/ggml-metal.metal b/ggml-metal.metal index f3a7efafa6613..d97952f2b0871 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1959,6 +1959,43 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; } +typedef void (flash_attn_ext_f16_t)( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]); + +template // head size kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2002,7 +2039,6 @@ kernel void kernel_flash_attn_ext_f16( return; } - const int64_t D = ne00; const int64_t D4 = D/4; // TODO: can we move this to the stack? @@ -2097,6 +2133,10 @@ kernel void kernel_flash_attn_ext_f16( } } +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, From 8cde449b8be4e481db2a8790d9320c743b3ed65e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 12:23:22 +0200 Subject: [PATCH 047/180] wip : 8 rows per simd group --- ggml-metal.m | 10 +-- ggml-metal.metal | 173 ++++++++++++++++++++++++++++++++++++----------- 2 files changed, 139 insertions(+), 44 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 046643146b3f3..0b1119c4eb467 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,14 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int nwarps = 1; + const int64_t nwarps = 2; - const size_t shalf = sizeof(float)/2; + const size_t smem = nwarps*(2*8*nwarps*ne00 + 128)*(sizeof(float)/2); - GGML_ASSERT(2*32*nwarps*ne00*shalf <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*shalf atIndex:0]; + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 8*nwarps - 1)/(8*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index d97952f2b0871..789b19bad6b93 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2031,33 +2031,20 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]; - const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; - - if (iq1 >= ne01) { - return; - } + //const int64_t iq3 = tgpig[2]; + //const int64_t iq2 = tgpig[1]; + //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; - const int64_t D4 = D/4; + const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups - // TODO: can we move this to the stack? - threadgroup half4 * V16 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); + const int64_t iq3 = tgpig[2]; + const int64_t iq2 = tgpig[1]*(8*nsg) + 8*sgitg + tiisg/4; + const int64_t iq1 = tgpig[0]; - // initialize with zeros - for (int64_t d = 0; d < D4; ++d) { - V16[d] = 0.0h; + if (iq2 >= ne02) { + return; } - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); - - half S = 0.0h; - half M = -INFINITY; - - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - - device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; - // assume K and V are same shape const int64_t ne22 = ne12; const int64_t ne23 = ne13; @@ -2081,11 +2068,97 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv2 = iq2 / rv2; const int64_t iv3 = iq3 / rv3; - // load Q to shared memory - for (int64_t d = 0; d < D4; ++d) { - pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + + device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; + +// const int64_t D4 = D/4; +// +// // TODO: can we move this to the stack? +// threadgroup half4x4 * V16 = (threadgroup half4x4 *) (shared); +// +// // initialize with zeros +// for (int64_t d = 0; d < D4; ++d) { +// +// } +// +// threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 4*D); +// +// // load Q to shared memory +// for (int64_t d = 0; d < D4; ++d) { +// pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; +// } +// +// half S = 0.0h; +// half M = -INFINITY; +// +// for (int64_t ic = 0; ic < ne11; ++ic) { +// const half mv = mp ? mp[ic] : 0.0h; +// if (mv == -INFINITY) { +// continue; +// } +// +// device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); +// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); +// +// half4 s4 = 0.0h; +// +// for (int64_t d = 0; d < D4; ++d) { +// s4 += pk4[d] * pq4[d]; +// } +// +// half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; +// +// const half Mold = M; +// +// M = max(M, s); +// +// const half ms = exp(Mold - M); +// const half vs = exp(s - M); +// +// for (int64_t d = 0; d < D4; ++d) { +// V16[d] = V16[d]*ms + pv4[d]*vs; +// } +// +// S = S*ms + vs; +// } +// +// for (int64_t d = 0; d < D4; ++d) { +// V16[d] /= S; +// } +// +// // dst indices +// const int64_t i1 = iq1; +// const int64_t i2 = iq2; +// const int64_t i3 = iq3; +// +// device float4 * dst4 = (device float4 *) dst; +// +// for (int64_t d = 0; d < D4; ++d) { +// dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; +// } + + const int64_t D4 = D/4; + + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) ); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 8*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 16*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(16*D + 128) + 16*D); + + const uint tiih = tiisg%4; // thread index in head + const uint hiisg = tiisg/4; // head index in simdgroup + + // load 8 heads from Q to shared memory + for (int64_t i = 0; i < D4/4; ++i) { + pq4[hiisg*D4 + 4*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[4*i + tiih]; + ps4[hiisg*D4 + 4*i + tiih] = 0.0h; } + simdgroup_barrier(mem_flags::mem_threadgroup); + + half S = 0.0h; + half M = -INFINITY; + for (int64_t ic = 0; ic < ne11; ++ic) { const half mv = mp ? mp[ic] : 0.0h; if (mv == -INFINITY) { @@ -2097,30 +2170,52 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; - for (int64_t d = 0; d < D4; ++d) { - s4 += pk4[d] * pq4[d]; + for (int64_t i = 0; i < D4/4; ++i) { + s4 += pk4[4*i + tiih] * pq4[hiisg*D4 + 4*i + tiih]; } - half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; + ss4[hiisg*4 + tiih] = s4; + + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (tiih == 0) { + s4 = ss4[4*hiisg + 0] + ss4[4*hiisg + 1] + ss4[4*hiisg + 2] + ss4[4*hiisg + 3]; + + half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; - const half Mold = M; + const half Mold = M; - M = max(M, s); + M = max(M, s); - const half ms = exp(Mold - M); - const half vs = exp(s - M); + const half ms = exp(Mold - M); + const half vs = exp(s - M); - for (int64_t d = 0; d < D4; ++d) { - V16[d] = V16[d]*ms + pv4[d]*vs; + S = S*ms + vs; + + ss[2*hiisg + 0] = ms; + ss[2*hiisg + 1] = vs; } - S = S*ms + vs; + simdgroup_barrier(mem_flags::mem_threadgroup); + + const half ms = ss[2*hiisg + 0]; + const half vs = ss[2*hiisg + 1]; + + for (int64_t i = 0; i < D4/4; ++i) { + ps4[hiisg*D4 + 4*i + tiih] = ps4[hiisg*D4 + 4*i + tiih]*ms + pv4[4*i + tiih]*vs; + } } - for (int64_t d = 0; d < D4; ++d) { - V16[d] /= S; + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (tiih == 0) { + for (int64_t i = 0; i < D4; ++i) { + ps4[hiisg*D4 + i] /= S; + } } + simdgroup_barrier(mem_flags::mem_threadgroup); + // dst indices const int64_t i1 = iq1; const int64_t i2 = iq2; @@ -2128,8 +2223,8 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t d = 0; d < D4; ++d) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; + for (int64_t i = 0; i < D4/4; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 4*i + tiih] = (float4) ps4[hiisg*D4 + 4*i + tiih]; } } From f31955f5d12da67f35aa459996a171975fdf269b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 18:01:28 +0200 Subject: [PATCH 048/180] wip : 4 rows per simd group --- ggml-metal.m | 6 +++--- ggml-metal.metal | 39 +++++++++++++++++++++------------------ 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 0b1119c4eb467..abb96d6ec6e44 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,14 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 2; + const int64_t nwarps = 4; - const size_t smem = nwarps*(2*8*nwarps*ne00 + 128)*(sizeof(float)/2); + const size_t smem = nwarps*(2*4*ne00 + 128)*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 8*nwarps - 1)/(8*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 4*nwarps - 1)/(4*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 789b19bad6b93..6fdd7fdad4326 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2038,7 +2038,7 @@ kernel void kernel_flash_attn_ext_f16( const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*(8*nsg) + 8*sgitg + tiisg/4; + const int64_t iq2 = tgpig[1]*(4*nsg) + 4*sgitg + tiisg/8; const int64_t iq1 = tgpig[0]; if (iq2 >= ne02) { @@ -2140,18 +2140,18 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) ); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 8*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 16*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(16*D + 128) + 16*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) ); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 4*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 2*4*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*4*D + 128) + 2*4*D); - const uint tiih = tiisg%4; // thread index in head - const uint hiisg = tiisg/4; // head index in simdgroup + const uint tiih = tiisg%8; // thread index in head + const uint hiisg = tiisg/8; // head index in simdgroup // load 8 heads from Q to shared memory - for (int64_t i = 0; i < D4/4; ++i) { - pq4[hiisg*D4 + 4*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[4*i + tiih]; - ps4[hiisg*D4 + 4*i + tiih] = 0.0h; + for (int64_t i = 0; i < D4/8; ++i) { + pq4[hiisg*D4 + 8*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[8*i + tiih]; + ps4[hiisg*D4 + 8*i + tiih] = 0.0h; } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -2170,16 +2170,18 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; - for (int64_t i = 0; i < D4/4; ++i) { - s4 += pk4[4*i + tiih] * pq4[hiisg*D4 + 4*i + tiih]; +#pragma unroll(D4/8) + for (int64_t i = 0; i < D4/8; ++i) { + s4 += pk4[8*i + tiih] * pq4[hiisg*D4 + 8*i + tiih]; } - ss4[hiisg*4 + tiih] = s4; + ss4[hiisg*8 + tiih] = s4; simdgroup_barrier(mem_flags::mem_threadgroup); if (tiih == 0) { - s4 = ss4[4*hiisg + 0] + ss4[4*hiisg + 1] + ss4[4*hiisg + 2] + ss4[4*hiisg + 3]; + s4 = ss4[8*hiisg + 0] + ss4[8*hiisg + 1] + ss4[8*hiisg + 2] + ss4[8*hiisg + 3] + + ss4[8*hiisg + 4] + ss4[8*hiisg + 5] + ss4[8*hiisg + 6] + ss4[8*hiisg + 7]; half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; @@ -2201,8 +2203,9 @@ kernel void kernel_flash_attn_ext_f16( const half ms = ss[2*hiisg + 0]; const half vs = ss[2*hiisg + 1]; - for (int64_t i = 0; i < D4/4; ++i) { - ps4[hiisg*D4 + 4*i + tiih] = ps4[hiisg*D4 + 4*i + tiih]*ms + pv4[4*i + tiih]*vs; +#pragma unroll(D4/8) + for (int64_t i = 0; i < D4/8; ++i) { + ps4[hiisg*D4 + 8*i + tiih] = ps4[hiisg*D4 + 8*i + tiih]*ms + pv4[8*i + tiih]*vs; } } @@ -2223,8 +2226,8 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t i = 0; i < D4/4; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 4*i + tiih] = (float4) ps4[hiisg*D4 + 4*i + tiih]; + for (int64_t i = 0; i < D4/8; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 8*i + tiih] = (float4) ps4[hiisg*D4 + 8*i + tiih]; } } From a4b6341c7b2a1977c29e79b17a0e5de3e31a5420 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 18:24:13 +0200 Subject: [PATCH 049/180] wip : template for rows per warp --- ggml-metal.m | 7 ++++--- ggml-metal.metal | 54 +++++++++++++++++++++++++----------------------- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index abb96d6ec6e44..d521df43ab302 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 4; + const int64_t nwarps = 8; + const int64_t nhpw = 4; // heads per warp - const size_t smem = nwarps*(2*4*ne00 + 128)*(sizeof(float)/2); + const size_t smem = nwarps*(2*nhpw*ne00 + 128)*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 4*nwarps - 1)/(4*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhpw*nwarps - 1)/(nhpw*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 6fdd7fdad4326..c9876c1033f1f 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size +template // head size, rows per warp kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2036,9 +2036,10 @@ kernel void kernel_flash_attn_ext_f16( //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups + const uint tph = N_SIMDWIDTH/R; // threads per head const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*(4*nsg) + 4*sgitg + tiisg/8; + const int64_t iq2 = tgpig[1]*(R*nsg) + R*sgitg + tiisg/tph; const int64_t iq1 = tgpig[0]; if (iq2 >= ne02) { @@ -2140,18 +2141,18 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) ); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 4*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 2*4*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*4*D + 128) + 2*4*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 0*R*D); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 1*R*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 2*R*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*R*D + 128) + 2*R*D); - const uint tiih = tiisg%8; // thread index in head - const uint hiisg = tiisg/8; // head index in simdgroup + const uint tiih = tiisg%tph; // thread index in head + const uint hiisg = tiisg/tph; // head index in simdgroup - // load 8 heads from Q to shared memory - for (int64_t i = 0; i < D4/8; ++i) { - pq4[hiisg*D4 + 8*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[8*i + tiih]; - ps4[hiisg*D4 + 8*i + tiih] = 0.0h; + // load R heads from Q to shared memory + for (int64_t i = 0; i < D4/tph; ++i) { + pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + ps4[hiisg*D4 + tph*i + tiih] = 0.0h; } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -2170,18 +2171,20 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; -#pragma unroll(D4/8) - for (int64_t i = 0; i < D4/8; ++i) { - s4 += pk4[8*i + tiih] * pq4[hiisg*D4 + 8*i + tiih]; + for (int64_t i = 0; i < D4/tph; ++i) { + s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; } - ss4[hiisg*8 + tiih] = s4; + ss4[hiisg*tph + tiih] = s4; simdgroup_barrier(mem_flags::mem_threadgroup); if (tiih == 0) { - s4 = ss4[8*hiisg + 0] + ss4[8*hiisg + 1] + ss4[8*hiisg + 2] + ss4[8*hiisg + 3] + - ss4[8*hiisg + 4] + ss4[8*hiisg + 5] + ss4[8*hiisg + 6] + ss4[8*hiisg + 7]; + s4 = 0.0h; + + for (int64_t i = 0; i < tph; ++i) { + s4 += ss4[hiisg*tph + i]; + } half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; @@ -2203,9 +2206,8 @@ kernel void kernel_flash_attn_ext_f16( const half ms = ss[2*hiisg + 0]; const half vs = ss[2*hiisg + 1]; -#pragma unroll(D4/8) - for (int64_t i = 0; i < D4/8; ++i) { - ps4[hiisg*D4 + 8*i + tiih] = ps4[hiisg*D4 + 8*i + tiih]*ms + pv4[8*i + tiih]*vs; + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; } } @@ -2226,14 +2228,14 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t i = 0; i < D4/8; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 8*i + tiih] = (float4) ps4[hiisg*D4 + 8*i + tiih]; + for (int64_t i = 0; i < D4/tph; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 4>; kernel void kernel_cpy_f16_f16( device const half * src0, From 77d08f3272c62900b40d110bf0de7f4466675c71 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 21:04:15 +0200 Subject: [PATCH 050/180] metal : parallelize across KV size --- ggml-metal.m | 8 +-- ggml-metal.metal | 137 +++++++++++++++++------------------------------ 2 files changed, 52 insertions(+), 93 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index d521df43ab302..a60dd779a6f09 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,15 +2252,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 8; - const int64_t nhpw = 4; // heads per warp + const int64_t nwarps = 16; + const int64_t nhptg = 4; // heads per threadgroup - const size_t smem = nwarps*(2*nhpw*ne00 + 128)*(sizeof(float)/2); + const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhpw*nwarps - 1)/(nhpw*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index c9876c1033f1f..539e26c91c34a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size, rows per warp +template // head size, rows per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2031,15 +2031,11 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - //const int64_t iq3 = tgpig[2]; - //const int64_t iq2 = tgpig[1]; - //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; - - const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups - const uint tph = N_SIMDWIDTH/R; // threads per head + const uint nsg = ntg.y; // number of simdgroups + const uint tph = N_SIMDWIDTH/R; // threads per head const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*(R*nsg) + R*sgitg + tiisg/tph; + const int64_t iq2 = tgpig[1]*R + tiisg/tph; const int64_t iq1 = tgpig[0]; if (iq2 >= ne02) { @@ -2073,94 +2069,30 @@ kernel void kernel_flash_attn_ext_f16( device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; -// const int64_t D4 = D/4; -// -// // TODO: can we move this to the stack? -// threadgroup half4x4 * V16 = (threadgroup half4x4 *) (shared); -// -// // initialize with zeros -// for (int64_t d = 0; d < D4; ++d) { -// -// } -// -// threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 4*D); -// -// // load Q to shared memory -// for (int64_t d = 0; d < D4; ++d) { -// pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; -// } -// -// half S = 0.0h; -// half M = -INFINITY; -// -// for (int64_t ic = 0; ic < ne11; ++ic) { -// const half mv = mp ? mp[ic] : 0.0h; -// if (mv == -INFINITY) { -// continue; -// } -// -// device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); -// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); -// -// half4 s4 = 0.0h; -// -// for (int64_t d = 0; d < D4; ++d) { -// s4 += pk4[d] * pq4[d]; -// } -// -// half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; -// -// const half Mold = M; -// -// M = max(M, s); -// -// const half ms = exp(Mold - M); -// const half vs = exp(s - M); -// -// for (int64_t d = 0; d < D4; ++d) { -// V16[d] = V16[d]*ms + pv4[d]*vs; -// } -// -// S = S*ms + vs; -// } -// -// for (int64_t d = 0; d < D4; ++d) { -// V16[d] /= S; -// } -// -// // dst indices -// const int64_t i1 = iq1; -// const int64_t i2 = iq2; -// const int64_t i3 = iq3; -// -// device float4 * dst4 = (device float4 *) dst; -// -// for (int64_t d = 0; d < D4; ++d) { -// dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; -// } - const int64_t D4 = D/4; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 0*R*D); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 1*R*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 2*R*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*R*D + 128) + 2*R*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*R*D); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(R*D + 32) + 1*R*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(R*D + 32) + 2*R*D); const uint tiih = tiisg%tph; // thread index in head const uint hiisg = tiisg/tph; // head index in simdgroup // load R heads from Q to shared memory for (int64_t i = 0; i < D4/tph; ++i) { - pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + if (sgitg == 0) { + pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + } + ps4[hiisg*D4 + tph*i + tiih] = 0.0h; } - simdgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); half S = 0.0h; half M = -INFINITY; - for (int64_t ic = 0; ic < ne11; ++ic) { + for (int64_t ic = sgitg; ic < ne11; ic += nsg) { const half mv = mp ? mp[ic] : 0.0h; if (mv == -INFINITY) { continue; @@ -2175,18 +2107,18 @@ kernel void kernel_flash_attn_ext_f16( s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; } - ss4[hiisg*tph + tiih] = s4; + ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w); simdgroup_barrier(mem_flags::mem_threadgroup); if (tiih == 0) { - s4 = 0.0h; + half s = 0.0h; for (int64_t i = 0; i < tph; ++i) { - s4 += ss4[hiisg*tph + i]; + s += ss[hiisg*tph + i]; } - half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; + s = s*scale + mv; const half Mold = M; @@ -2211,9 +2143,34 @@ kernel void kernel_flash_attn_ext_f16( } } - simdgroup_barrier(mem_flags::mem_threadgroup); - if (tiih == 0) { + ss[2*hiisg + 0] = S; + ss[2*hiisg + 1] = M; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // reduce the warps + if (sgitg == 0 && tiih == 0) { + for (int64_t sg = 1; sg < nsg; ++sg) { + const half S0 = S; + const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; + + const half M0 = M; + const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; + + M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + for (int64_t i = 0; i < D4; ++i) { + ps4[hiisg*D4 + i] = ps4[hiisg*D4 + i]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + i]*ms1; + } + } + for (int64_t i = 0; i < D4; ++i) { ps4[hiisg*D4 + i] /= S; } @@ -2228,8 +2185,10 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t i = 0; i < D4/tph; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; + if (sgitg == 0) { + for (int64_t i = 0; i < D4/tph; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; + } } } From 17720fad669eed6171ddf17184da5bab50adeb72 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 22:44:41 +0200 Subject: [PATCH 051/180] metal : parallel reduce across heads --- ggml-metal.m | 4 ++-- ggml-metal.metal | 32 ++++++++++++++++++++------------ 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a60dd779a6f09..fdfb50d3d03f4 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,8 +2252,8 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 16; - const int64_t nhptg = 4; // heads per threadgroup + const int64_t nwarps = 32; + const int64_t nhptg = 2; // heads per threadgroup const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2); diff --git a/ggml-metal.metal b/ggml-metal.metal index 539e26c91c34a..919119c8d55af 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2103,6 +2103,7 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; +#pragma unroll for (int64_t i = 0; i < D4/tph; ++i) { s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; } @@ -2114,17 +2115,18 @@ kernel void kernel_flash_attn_ext_f16( if (tiih == 0) { half s = 0.0h; +#pragma unroll for (int64_t i = 0; i < tph; ++i) { s += ss[hiisg*tph + i]; } s = s*scale + mv; - const half Mold = M; + const half m = M; M = max(M, s); - const half ms = exp(Mold - M); + const half ms = exp(m - M); const half vs = exp(s - M); S = S*ms + vs; @@ -2138,6 +2140,7 @@ kernel void kernel_flash_attn_ext_f16( const half ms = ss[2*hiisg + 0]; const half vs = ss[2*hiisg + 1]; +#pragma unroll for (int64_t i = 0; i < D4/tph; ++i) { ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; } @@ -2151,12 +2154,12 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); // reduce the warps - if (sgitg == 0 && tiih == 0) { + if (sgitg == 0) { for (int64_t sg = 1; sg < nsg; ++sg) { - const half S0 = S; + const half S0 = ss[ 2*hiisg + 0]; const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; - const half M0 = M; + const half M0 = ss[ 2*hiisg + 1]; const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; M = max(M0, M1); @@ -2166,13 +2169,18 @@ kernel void kernel_flash_attn_ext_f16( S = S0*ms0 + S1*ms1; - for (int64_t i = 0; i < D4; ++i) { - ps4[hiisg*D4 + i] = ps4[hiisg*D4 + i]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + i]*ms1; + if (tiih == 0) { + ss[2*hiisg + 0] = S; + ss[2*hiisg + 1] = M; + } + + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1; } } - for (int64_t i = 0; i < D4; ++i) { - ps4[hiisg*D4 + i] /= S; + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S; } } @@ -2192,9 +2200,9 @@ kernel void kernel_flash_attn_ext_f16( } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2>; kernel void kernel_cpy_f16_f16( device const half * src0, From 925e3d12890f51d16dd2e5581801526063a6d350 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 23 Jan 2024 19:21:38 +0900 Subject: [PATCH 052/180] build from clblast source --- CMakeLists.txt | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b2e5421e0174d..1247d2dd06dcc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -387,19 +387,21 @@ if (LLAMA_MPI) endif() if (LLAMA_CLBLAST) - find_package(CLBlast) - if (CLBlast_FOUND) - message(STATUS "CLBlast found") + # build CLBlast from source + # Add CLBlast directory + add_subdirectory(CLBlast) - set(GGML_HEADERS_OPENCL ggml-opencl.h) - set(GGML_SOURCES_OPENCL ggml-opencl.cpp) + # Include directories for CLBlast + include_directories(CLBlast/include) - add_compile_definitions(GGML_USE_CLBLAST) + message(STATUS "CLBlast source found") - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast) - else() - message(WARNING "CLBlast not found") - endif() + set(GGML_HEADERS_OPENCL ggml-opencl.h) + set(GGML_SOURCES_OPENCL ggml-opencl.cpp) + + add_compile_definitions(GGML_USE_CLBLAST) + + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast) endif() if (LLAMA_HIPBLAS) @@ -915,4 +917,4 @@ endif () if (LLAMA_BUILD_EXAMPLES) add_subdirectory(examples) add_subdirectory(pocs) -endif() \ No newline at end of file +endif() From 4cca5cca71b311cb4383feda381002a9861bbe58 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 23 Jan 2024 22:19:50 +0900 Subject: [PATCH 053/180] updated llama.cpp to build openblas from source --- CMakeLists.txt | 82 ++++++++++---------------------------------------- 1 file changed, 16 insertions(+), 66 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1247d2dd06dcc..492f2baa2bcfd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -222,75 +222,25 @@ if (LLAMA_BLAS) set(BLA_SIZEOF_INTEGER 8) endif() - set(BLA_VENDOR ${LLAMA_BLAS_VENDOR}) - find_package(BLAS) - - if (BLAS_FOUND) - message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") - - if ("${BLAS_INCLUDE_DIRS}" STREQUAL "") - # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake. - # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268 - find_package(PkgConfig REQUIRED) - if (${LLAMA_BLAS_VENDOR} MATCHES "Generic") - pkg_check_modules(DepBLAS REQUIRED blas) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "OpenBLAS") - # As of openblas v0.3.22, the 64-bit is named openblas64.pc - pkg_check_modules(DepBLAS openblas64) - if (NOT DepBLAS_FOUND) - pkg_check_modules(DepBLAS REQUIRED openblas) - endif() - elseif (${LLAMA_BLAS_VENDOR} MATCHES "FLAME") - pkg_check_modules(DepBLAS REQUIRED blis) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "ATLAS") - pkg_check_modules(DepBLAS REQUIRED blas-atlas) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "FlexiBLAS") - pkg_check_modules(DepBLAS REQUIRED flexiblas_api) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "Intel") - # all Intel* libraries share the same include path - pkg_check_modules(DepBLAS REQUIRED mkl-sdl) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "NVHPC") - # this doesn't provide pkg-config - # suggest to assign BLAS_INCLUDE_DIRS on your own - if ("${NVHPC_VERSION}" STREQUAL "") - message(WARNING "Better to set NVHPC_VERSION") - else() - set(DepBLAS_FOUND ON) - set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include") - endif() - endif() - if (DepBLAS_FOUND) - set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS}) - else() - message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically" - " detected by pkgconfig, trying to find cblas.h from possible paths...") - find_path(BLAS_INCLUDE_DIRS - NAMES cblas.h - HINTS - /usr/include - /usr/local/include - /usr/include/openblas - /opt/homebrew/opt/openblas/include - /usr/local/opt/openblas/include - /usr/include/x86_64-linux-gnu/openblas/include - ) - endif() - endif() + # Include the OpenBLAS directory + add_subdirectory(../OpenBLAS {CMAKE_CURRENT_BINARY_DIR}/OpenBLAS) - message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}") - add_compile_options(${BLAS_LINKER_FLAGS}) - add_compile_definitions(GGML_USE_OPENBLAS) - if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${LLAMA_BLAS_VENDOR} MATCHES "Generic" OR ${LLAMA_BLAS_VENDOR} MATCHES "Intel")) - add_compile_definitions(GGML_BLAS_USE_MKL) - endif() - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES}) - set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS}) + # Include directories for OpenBLAS + include_directories(../OpenBLAS) - else() - message(WARNING "BLAS not found, please refer to " - "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" - " to set correct LLAMA_BLAS_VENDOR") + set(BLAS_LIBRARIES openblas) + set(BLAS_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../OpenBLAS) + + message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") + message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}") + + add_compile_options(${BLAS_LINKER_FLAGS}) + add_compile_definitions(GGML_USE_OPENBLAS) + if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${LLAMA_BLAS_VENDOR} MATCHES "Generic" OR ${LLAMA_BLAS_VENDOR} MATCHES "Intel")) + add_compile_definitions(GGML_BLAS_USE_MKL) endif() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES}) + set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS}) endif() if (LLAMA_QKK_64) From 83dd88ad0674e4cfbc1bbbfaf1d8f10b6fb27a6d Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 23 Jan 2024 22:51:49 +0900 Subject: [PATCH 054/180] updated cmake to find clblast in sibling directory --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 492f2baa2bcfd..a3ab09bfd7609 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -339,10 +339,10 @@ endif() if (LLAMA_CLBLAST) # build CLBlast from source # Add CLBlast directory - add_subdirectory(CLBlast) + add_subdirectory(../CLBlast ${CMAKE_CURRENT_BINARY_DIR}/clblast) # Include directories for CLBlast - include_directories(CLBlast/include) + include_directories(../CLBlast/include) message(STATUS "CLBlast source found") From 1446a12b29f422a0c0040e62c16715a3fb7ce1cb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 23 Jan 2024 18:27:54 +0200 Subject: [PATCH 055/180] metal : efficient flash_attn_f16 implementation --- ggml-metal.m | 14 +- ggml-metal.metal | 279 +++++++++++++++++++++++-------------- tests/test-backend-ops.cpp | 6 +- 3 files changed, 188 insertions(+), 111 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index fdfb50d3d03f4..7b161c69d5801 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2183,6 +2183,7 @@ static bool ggml_metal_graph_compute( struct ggml_tensor * src3 = gf->nodes[i]->src[3]; GGML_ASSERT(ggml_are_same_shape(src1, src2)); + GGML_ASSERT(src3); size_t offs_src2 = 0; size_t offs_src3 = 0; @@ -2252,15 +2253,20 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 32; - const int64_t nhptg = 2; // heads per threadgroup + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + const int64_t nsg = ne01 < 4 ? 4 : 2; // simdgroups per threadgroup (a.k.a. warps) - const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2); + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) + //const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2); + const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 919119c8d55af..9b6ceec4e1066 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size, rows per threadgroup +template // head size, heads per threadgroup, queries per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2031,178 +2031,247 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const uint nsg = ntg.y; // number of simdgroups - const uint tph = N_SIMDWIDTH/R; // threads per head + const uint nsg = ntg.y; // number of simdgroups const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*R + tiisg/tph; - const int64_t iq1 = tgpig[0]; + const int64_t iq2 = tgpig[1]; + const int64_t iq1 = tgpig[0]*Q; if (iq2 >= ne02) { return; } - // assume K and V are same shape - const int64_t ne22 = ne12; - const int64_t ne23 = ne13; + const int64_t D4 = D/4; + const int64_t N4 = N_SIMDWIDTH; + const int64_t L4 = (D4 + N4 - 1)/N4; + const int64_t D8 = D/8; + + const int64_t T = D + nsg*(D + 1*C); // shared memory size per query in half + const int64_t T4 = T/4; // shared memory size per query in half4 + + threadgroup half * pq = (threadgroup half *) (shared + 0*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D); + threadgroup half * ps = (threadgroup half *) (shared + sgitg*(D + 1*C) + 1*D); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(D + 1*C) + 1*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(D + 1*C) + 2*D); + + for (int64_t i = 0; i < L4; ++i) { + // load heads from Q to shared memory + for (int64_t j = sgitg; j < Q; j += nsg) { + if (iq1 + j < ne01) { + pq4[j*T4 + N4*i + tiisg] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + tiisg]; + } else { + pq4[j*T4 + N4*i + tiisg] = 0.0h; + } + } - const uint64_t nb21 = nb11; - const uint64_t nb22 = nb12; - const uint64_t nb23 = nb13; + // zero out shared memory + for (int64_t j = 0; j < Q; ++j) { + ps4[j*T4 + N4*i + tiisg] = 0.0h; + } + } - // broadcast - const int64_t rk2 = ne02/ne12; - const int64_t rk3 = ne03/ne13; + if (tiisg < C) { + for (int64_t j = 0; j < Q; ++j) { + ss[j*T + 0 + tiisg] = 0.0h; + } + } - const int64_t rv2 = ne02/ne22; - const int64_t rv3 = ne03/ne23; + threadgroup_barrier(mem_flags::mem_threadgroup); - // k indices - const int64_t ik2 = iq2 / rk2; - const int64_t ik3 = iq3 / rk3; + { + half S[Q] = { 0.0h }; + half M[Q] = { -INFINITY }; - // v indices - const int64_t iv2 = iq2 / rv2; - const int64_t iv3 = iq3 / rv3; + // assume K and V are same shape + const int64_t ne22 = ne12; + const int64_t ne23 = ne13; - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + const uint64_t nb21 = nb11; + const uint64_t nb22 = nb12; + const uint64_t nb23 = nb13; - device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; + // broadcast + const int64_t rk2 = ne02/ne12; + const int64_t rk3 = ne03/ne13; - const int64_t D4 = D/4; + const int64_t rv2 = ne02/ne22; + const int64_t rv3 = ne03/ne23; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*R*D); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(R*D + 32) + 1*R*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(R*D + 32) + 2*R*D); + // k indices + const int64_t ik2 = iq2 / rk2; + const int64_t ik3 = iq3 / rk3; - const uint tiih = tiisg%tph; // thread index in head - const uint hiisg = tiisg/tph; // head index in simdgroup + // v indices + const int64_t iv2 = iq2 / rv2; + const int64_t iv3 = iq3 / rv3; - // load R heads from Q to shared memory - for (int64_t i = 0; i < D4/tph; ++i) { - if (sgitg == 0) { - pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; - } + simdgroup_half8x8 mq[D8]; - ps4[hiisg*D4 + tph*i + tiih] = 0.0h; - } + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load(mq[i], pq + i*8, T); + } - threadgroup_barrier(mem_flags::mem_threadgroup); + // TODO: this can be improved + device const float * mp[Q]; - half S = 0.0h; - half M = -INFINITY; + { + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - for (int64_t ic = sgitg; ic < ne11; ic += nsg) { - const half mv = mp ? mp[ic] : 0.0h; - if (mv == -INFINITY) { - continue; + for (int64_t j = 0; j < Q; ++j) { + if (iq1 + j < ne01) { + mp[j] = (device const float *) (mask + ((ir + j)%ne31)*nb31); + } else { + mp[j] = nullptr; + } + } } - device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); - device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { + // skip -INF blocks + // TODO: double-check this + { + float smc = -INFINITY; - half4 s4 = 0.0h; + for (int64_t j = 0; j < Q; ++j) { + const float mc = mp[j] ? mp[j][iic + tiisg] : -INFINITY; + smc = simd_max(max(smc, mc)); + } -#pragma unroll - for (int64_t i = 0; i < D4/tph; ++i) { - s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; - } + if (smc == -INFINITY) { + continue; + } + } + + // Q*K^T + { + simdgroup_half8x8 mk; - ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w); + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); - simdgroup_barrier(mem_flags::mem_threadgroup); + device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); - if (tiih == 0) { - half s = 0.0h; + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load(mk, pk + i*8, nb11/2, 0, true); -#pragma unroll - for (int64_t i = 0; i < tph; ++i) { - s += ss[hiisg*tph + i]; + simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + } + + simdgroup_store(mqk, ss + 8*cc, T, 0, false); + } } - s = s*scale + mv; + // online softmax + for (int64_t j = 0; j < Q; ++j) { + const int64_t p = tiisg; - const half m = M; + const half s = ss[j*T + p]*scale + (mp[j][iic + p]); - M = max(M, s); + half m = M[j]; - const half ms = exp(m - M); - const half vs = exp(s - M); + M[j] = simd_max(max(M[j], s)); - S = S*ms + vs; + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - ss[2*hiisg + 0] = ms; - ss[2*hiisg + 1] = vs; - } + S[j] = S[j]*ms + simd_sum(vs); + + for (int64_t i = 0; i < L4; ++i) { + ps4[j*T4 + N4*i + tiisg] *= ms; + } + + ss[j*T + p] = vs; + } + + // (Q*K^T)*V + { + simdgroup_half8x8 mv; + + for (int64_t i = 0; i < D8; ++i) { + simdgroup_half8x8 mp[C/8]; + simdgroup_half8x8 mqkv; - simdgroup_barrier(mem_flags::mem_threadgroup); + simdgroup_load(mqkv, ps + i*8, T, 0, false); - const half ms = ss[2*hiisg + 0]; - const half vs = ss[2*hiisg + 1]; + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); + } + + for (int cc = 0; cc < C/8; ++cc) { + device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); -#pragma unroll - for (int64_t i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; + simdgroup_load(mv, pv + i*8, nb21/2, 0, false); + + simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv); + } + + simdgroup_store(mqkv, ps + i*8, T, 0, false); + } + } } - } - if (tiih == 0) { - ss[2*hiisg + 0] = S; - ss[2*hiisg + 1] = M; + for (int64_t j = 0; j < Q; ++j) { + if (tiisg == 0) { + ss[j*T + 0] = S[j]; + ss[j*T + 1] = M[j]; + } + } } threadgroup_barrier(mem_flags::mem_threadgroup); // reduce the warps + // TODO: try parallel reduce if (sgitg == 0) { + half S = { 0.0h }; + half M = { -INFINITY }; + for (int64_t sg = 1; sg < nsg; ++sg) { - const half S0 = ss[ 2*hiisg + 0]; - const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; + for (int64_t j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*(D + 1*C) + 0]; - const half M0 = ss[ 2*hiisg + 1]; - const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*(D + 1*C) + 1]; - M = max(M0, M1); + M = max(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); - S = S0*ms0 + S1*ms1; + S = S0*ms0 + S1*ms1; - if (tiih == 0) { - ss[2*hiisg + 0] = S; - ss[2*hiisg + 1] = M; - } + if (tiisg == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; + } - for (int64_t i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1; + for (int64_t i = 0; i < L4; ++i) { + ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; + } } } - - for (int64_t i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S; - } } simdgroup_barrier(mem_flags::mem_threadgroup); - // dst indices - const int64_t i1 = iq1; - const int64_t i2 = iq2; - const int64_t i3 = iq3; - device float4 * dst4 = (device float4 *) dst; if (sgitg == 0) { - for (int64_t i = 0; i < D4/tph; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; + for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + const half S = ss[j*T + 0]; + + for (int64_t i = 0; i < L4; ++i) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ps4[j*T4 + N4*i + tiisg]/S; + } } } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; kernel void kernel_cpy_f16_f16( device const half * src0, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 51a33c662da56..41ddfcca5b687 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1397,7 +1397,7 @@ struct test_flash_attn_ext : public test_case { } double max_nmse_err() override { - return 5e-4; + return 5e-5; } test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, @@ -1680,7 +1680,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 96, 8)); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 8)); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 7)); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From d917746ddb053b73e868fd6e1854ac17b62bd863 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 15:00:49 +0200 Subject: [PATCH 056/180] metal : avoid redundant loads of the attention --- ggml-metal.metal | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 9b6ceec4e1066..785a60e50eba8 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2184,20 +2184,22 @@ kernel void kernel_flash_attn_ext_f16( ss[j*T + p] = vs; } + simdgroup_barrier(mem_flags::mem_none); + // (Q*K^T)*V { simdgroup_half8x8 mv; + simdgroup_half8x8 mp[C/8]; + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); + } + for (int64_t i = 0; i < D8; ++i) { - simdgroup_half8x8 mp[C/8]; simdgroup_half8x8 mqkv; simdgroup_load(mqkv, ps + i*8, T, 0, false); - for (int cc = 0; cc < C/8; ++cc) { - simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); - } - for (int cc = 0; cc < C/8; ++cc) { device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); From 432ad04ffaa445a3837b92dce1c03513009ab4ac Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 15:47:52 +0200 Subject: [PATCH 057/180] metal : scale and mask in matrix form --- ggml-metal.metal | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 785a60e50eba8..ae8f5caeaa75f 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2127,6 +2127,9 @@ kernel void kernel_flash_attn_ext_f16( } } + // prepare diagonal scale matrix + simdgroup_half8x8 mscale(scale); + for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { // skip -INF blocks // TODO: double-check this @@ -2153,11 +2156,16 @@ kernel void kernel_flash_attn_ext_f16( device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mk, pk + i*8, nb11/2, 0, true); + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } + // mqk = mqk*scale + mask + simdgroup_float8x8 mm; + simdgroup_load(mm, mp[0] + iic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + simdgroup_store(mqk, ss + 8*cc, T, 0, false); } } @@ -2166,7 +2174,8 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t j = 0; j < Q; ++j) { const int64_t p = tiisg; - const half s = ss[j*T + p]*scale + (mp[j][iic + p]); + //const half s = ss[j*T + p]*scale + (mp[j][iic + p]); + const half s = ss[j*T + p]; half m = M[j]; @@ -2203,7 +2212,7 @@ kernel void kernel_flash_attn_ext_f16( for (int cc = 0; cc < C/8; ++cc) { device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); - simdgroup_load(mv, pv + i*8, nb21/2, 0, false); + simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv); } From 40ea8cd1aca61294e1987bcb1051317827f1b145 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 16:31:39 +0200 Subject: [PATCH 058/180] metal : fix comment --- ggml-metal.metal | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index ae8f5caeaa75f..9ab9e16c3915a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size, heads per threadgroup, queries per threadgroup +template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2031,16 +2031,12 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const uint nsg = ntg.y; // number of simdgroups + const uint nsg = ntg.y; // number of simdgroups const int64_t iq3 = tgpig[2]; const int64_t iq2 = tgpig[1]; const int64_t iq1 = tgpig[0]*Q; - if (iq2 >= ne02) { - return; - } - const int64_t D4 = D/4; const int64_t N4 = N_SIMDWIDTH; const int64_t L4 = (D4 + N4 - 1)/N4; From f9ca5dcbe86a10cfa873814d5f754b7c9108f339 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 17:46:07 +0200 Subject: [PATCH 059/180] llama : avoid ggml_cast, use F32 query --- ggml-metal.m | 4 ++-- ggml-metal.metal | 3 ++- ggml.c | 31 +++++++++++++++++++++++++++---- ggml.h | 4 ++++ llama.cpp | 3 ++- tests/test-backend-ops.cpp | 16 +++++++--------- 6 files changed, 44 insertions(+), 17 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 7b161c69d5801..7b6762e6d9158 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2177,7 +2177,7 @@ static bool ggml_metal_graph_compute( case GGML_OP_FLASH_ATTN_EXT: { GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == GGML_TYPE_F32); struct ggml_tensor * src2 = gf->nodes[i]->src[2]; struct ggml_tensor * src3 = gf->nodes[i]->src[3]; @@ -2254,7 +2254,7 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&scale length:sizeof( float) atIndex:27]; // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 < 4 ? 4 : 2; // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 < 4 ? 12 : 2; // simdgroups per threadgroup (a.k.a. warps) const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) diff --git a/ggml-metal.metal b/ggml-metal.metal index 9ab9e16c3915a..c9e4dcfe99cd4 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2054,8 +2054,9 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t i = 0; i < L4; ++i) { // load heads from Q to shared memory for (int64_t j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); if (iq1 + j < ne01) { - pq4[j*T4 + N4*i + tiisg] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + tiisg]; + pq4[j*T4 + N4*i + tiisg] = (half4) q4[N4*i + tiisg]; } else { pq4[j*T4 + N4*i + tiisg] = 0.0h; } diff --git a/ggml.c b/ggml.c index 10df03c9c619b..5e515c03fdb9d 100644 --- a/ggml.c +++ b/ggml.c @@ -4178,6 +4178,8 @@ struct ggml_tensor * ggml_mul_mat( void ggml_mul_mat_set_prec( struct ggml_tensor * a, enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_MUL_MAT); + const int32_t prec_i32 = (int32_t) prec; ggml_set_op_params_i32(a, 0, prec_i32); @@ -5781,6 +5783,16 @@ struct ggml_tensor * ggml_flash_attn_ext( return result; } +void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + + const int32_t prec_i32 = (int32_t) prec; + + ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos +} + // ggml_flash_ff struct ggml_tensor * ggml_flash_ff( @@ -13347,7 +13359,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(ne2 == N); GGML_ASSERT(P >= 0); - GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbq0 == sizeof(float)); GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); @@ -13408,6 +13420,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( float M = -INFINITY; float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); + ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D); memset(V16, 0, D*sizeof(ggml_fp16_t)); @@ -13433,10 +13446,19 @@ static void ggml_compute_forward_flash_attn_ext_f16( float s; + // convert Q to F16 in V32 + { + const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); + + for (int64_t d = 0; d < D; ++d) { + Q16[d] = GGML_FP32_TO_FP16(pq[d]); + } + } + ggml_vec_dot_f16(D, &s, (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + Q16); s = s*scale + mv; @@ -13488,13 +13510,14 @@ static void ggml_compute_forward_flash_attn_ext( const struct ggml_tensor * v, const struct ggml_tensor * mask, struct ggml_tensor * dst) { - switch (q->type) { - case GGML_TYPE_F16: + switch (dst->op_params[1]) { + case GGML_PREC_DEFAULT: { ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); } break; default: { + // TODO: implement F32 precision GGML_ASSERT(false); } break; } diff --git a/ggml.h b/ggml.h index 7bca02f2a2c48..e2f74412fde1e 100644 --- a/ggml.h +++ b/ggml.h @@ -1633,6 +1633,10 @@ extern "C" { struct ggml_tensor * mask, float scale); + GGML_API void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec); + GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index 4e6c9f9cc75ea..550caced4ae57 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4368,7 +4368,8 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT); //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 41ddfcca5b687..db1244876ce06 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1386,26 +1386,24 @@ struct test_leaky_relu : public test_case { // GGML_OP_FLASH_ATTN_EXT struct test_flash_attn_ext : public test_case { - const ggml_type typeq; const int64_t hs; // head size const int64_t nh; // num heads const int64_t kv; // kv size const int64_t nb; // batch size std::string vars() override { - return VARS_TO_STR5(typeq, hs, nh, kv, nb); + return VARS_TO_STR4(hs, nh, kv, nb); } double max_nmse_err() override { return 5e-5; } - test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, - int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) - : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {} + test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : hs(hs), nh(nh), kv(kv), nb(nb) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1); + ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); @@ -1680,9 +1678,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 8)); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 7)); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 1)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 8)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 7)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 6fea843b246409a3c4b26156745a89e4ba01029b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 17:59:41 +0200 Subject: [PATCH 060/180] metal : add parallel reduce version (disabled) --- ggml-metal.m | 2 +- ggml-metal.metal | 42 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 7b6762e6d9158..cf7880c822db5 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2254,7 +2254,7 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&scale length:sizeof( float) atIndex:27]; // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 < 4 ? 12 : 2; // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 < 4 ? 12 : 4; // simdgroups per threadgroup (a.k.a. warps) const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) diff --git a/ggml-metal.metal b/ggml-metal.metal index c9e4dcfe99cd4..6eb2825df558b 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2230,7 +2230,7 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); // reduce the warps - // TODO: try parallel reduce +#if 1 if (sgitg == 0) { half S = { 0.0h }; half M = { -INFINITY }; @@ -2261,6 +2261,46 @@ kernel void kernel_flash_attn_ext_f16( } } } +#else + // parallel reduce + // NOTE: this is significantly slower than the serial version above, likely due to the small number of warps + { + half S = { 0.0h }; + half M = { -INFINITY }; + + for (int64_t sg = nsg/2; sg > 0; sg /= 2) { + if (sgitg >= sg) { + continue; + } + + for (int64_t j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*(D + 1*C) + 0]; + + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*(D + 1*C) + 1]; + + M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; + } + + for (int64_t i = 0; i < L4; ++i) { + ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } +#endif simdgroup_barrier(mem_flags::mem_threadgroup); From 77f6976a87f6d034cf0f7a77e14a011da7901911 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 13:15:00 +0200 Subject: [PATCH 061/180] metal : move output into local memory + optimize - the result from each simdgroup now stays in the registers - significantly reduced SRAM usage - more efficient skipping of -INF blocks - avoid simdgroup barrier in hot loop - add comments --- ggml-metal.m | 12 +-- ggml-metal.metal | 220 ++++++++++++++++++++++------------------------- 2 files changed, 110 insertions(+), 122 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index eabc16f416645..a7e126bff5318 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2213,14 +2213,14 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 < 4 ? 12 : 4; // simdgroups per threadgroup (a.k.a. warps) - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) + const int64_t ncpsg = 32; // cache values per simdgroup + + // simdgroups per threadgroup (a.k.a. warps) + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/32, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; - //const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2); - const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2); + const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); diff --git a/ggml-metal.metal b/ggml-metal.metal index 6eb2825df558b..b564f014de2b6 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,6 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); +// ref: https://arxiv.org/pdf/2307.08691.pdf template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, @@ -2038,39 +2039,45 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iq1 = tgpig[0]*Q; const int64_t D4 = D/4; - const int64_t N4 = N_SIMDWIDTH; - const int64_t L4 = (D4 + N4 - 1)/N4; const int64_t D8 = D/8; + const int64_t NW = N_SIMDWIDTH; + const int64_t L4 = (D4 + NW - 1)/NW; + const int64_t SH = (C + Q); // shared memory per simdgroup in (half) - const int64_t T = D + nsg*(D + 1*C); // shared memory size per query in half - const int64_t T4 = T/4; // shared memory size per query in half4 + const int64_t T = D + nsg*SH; // shared memory size per query in (half) + const int64_t T4 = T/4; // shared memory size per query in (half4) - threadgroup half * pq = (threadgroup half *) (shared + 0*D); - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D); - threadgroup half * ps = (threadgroup half *) (shared + sgitg*(D + 1*C) + 1*D); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(D + 1*C) + 1*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(D + 1*C) + 2*D); + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // scratch buffer for attention + threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + simdgroup_half8x8 lo[D8]; for (int64_t i = 0; i < L4; ++i) { // load heads from Q to shared memory for (int64_t j = sgitg; j < Q; j += nsg) { device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); if (iq1 + j < ne01) { - pq4[j*T4 + N4*i + tiisg] = (half4) q4[N4*i + tiisg]; + sq4[j*T4 + NW*i + tiisg] = (half4) q4[NW*i + tiisg]; } else { - pq4[j*T4 + N4*i + tiisg] = 0.0h; + sq4[j*T4 + NW*i + tiisg] = 0.0h; } } + } - // zero out shared memory - for (int64_t j = 0; j < Q; ++j) { - ps4[j*T4 + N4*i + tiisg] = 0.0h; - } + // zero out lo + for (int64_t i = 0; i < D8; ++i) { + lo[i] = make_filled_simdgroup_matrix(0.0h); } + // zero out shared memory SH if (tiisg < C) { for (int64_t j = 0; j < Q; ++j) { - ss[j*T + 0 + tiisg] = 0.0h; + ss[j*T + tiisg] = 0.0h; + if (tiisg < Q) { + ss[j*T + C + tiisg] = 0.0h; + } } } @@ -2103,46 +2110,24 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv2 = iq2 / rv2; const int64_t iv3 = iq3 / rv3; + // load the queries from shared memory into local memory simdgroup_half8x8 mq[D8]; for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mq[i], pq + i*8, T); + simdgroup_load(mq[i], sq + i*8, T); } - // TODO: this can be improved - device const float * mp[Q]; + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - { - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - - for (int64_t j = 0; j < Q; ++j) { - if (iq1 + j < ne01) { - mp[j] = (device const float *) (mask + ((ir + j)%ne31)*nb31); - } else { - mp[j] = nullptr; - } - } - } + // pointer to the mask + device const float * mp = (device const float *) (mask + (ir%ne31)*nb31); // prepare diagonal scale matrix simdgroup_half8x8 mscale(scale); - for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { - // skip -INF blocks - // TODO: double-check this - { - float smc = -INFINITY; - - for (int64_t j = 0; j < Q; ++j) { - const float mc = mp[j] ? mp[j][iic + tiisg] : -INFINITY; - smc = simd_max(max(smc, mc)); - } - - if (smc == -INFINITY) { - continue; - } - } - + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) { // Q*K^T { simdgroup_half8x8 mk; @@ -2150,7 +2135,7 @@ kernel void kernel_flash_attn_ext_f16( for (int cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); - device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); + device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); for (int64_t i = 0; i < D8; ++i) { simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); @@ -2160,65 +2145,77 @@ kernel void kernel_flash_attn_ext_f16( // mqk = mqk*scale + mask simdgroup_float8x8 mm; - simdgroup_load(mm, mp[0] + iic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(float), 0, false); simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); simdgroup_store(mqk, ss + 8*cc, T, 0, false); } } + // used to detect blocks full of -INF + half smax = -INFINITY; + // online softmax for (int64_t j = 0; j < Q; ++j) { const int64_t p = tiisg; - //const half s = ss[j*T + p]*scale + (mp[j][iic + p]); const half s = ss[j*T + p]; - half m = M[j]; - + smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); + const half m = M[j]; + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); S[j] = S[j]*ms + simd_sum(vs); - for (int64_t i = 0; i < L4; ++i) { - ps4[j*T4 + N4*i + tiisg] *= ms; + // create an 8x8 diagonal matrix for rescaling the output + if (p == j) { + ss[j*T + C + j] = ms; } + // the P matrix from the paper (Q rows, C columns) ss[j*T + p] = vs; } - simdgroup_barrier(mem_flags::mem_none); + // skip -INF blocks + if (smax == -INFINITY) { + continue; + } - // (Q*K^T)*V + // O = diag(ms)*O { - simdgroup_half8x8 mv; + simdgroup_half8x8 mm; - simdgroup_half8x8 mp[C/8]; - for (int cc = 0; cc < C/8; ++cc) { - simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); - } + simdgroup_load(mm, ss + C, T, 0, false); for (int64_t i = 0; i < D8; ++i) { - simdgroup_half8x8 mqkv; + simdgroup_multiply(lo[i], mm, lo[i]); + } + } - simdgroup_load(mqkv, ps + i*8, T, 0, false); + // O = O + (Q*K^T)*V + { + simdgroup_half8x8 mv; + + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_half8x8 mp; + simdgroup_load(mp, ss + 8*cc, T, 0, false); - for (int cc = 0; cc < C/8; ++cc) { - device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + for (int64_t i = 0; i < D8; ++i) { + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); - simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv); + simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]); } - - simdgroup_store(mqkv, ps + i*8, T, 0, false); } } } + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) for (int64_t j = 0; j < Q; ++j) { if (tiisg == 0) { ss[j*T + 0] = S[j]; @@ -2227,58 +2224,30 @@ kernel void kernel_flash_attn_ext_f16( } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // reduce the warps -#if 1 - if (sgitg == 0) { + // reduce the warps sequentially + for (int64_t sg = 1; sg < nsg; ++sg) { half S = { 0.0h }; half M = { -INFINITY }; - for (int64_t sg = 1; sg < nsg; ++sg) { - for (int64_t j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*(D + 1*C) + 0]; - - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*(D + 1*C) + 1]; - - M = max(M0, M1); - - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); - - S = S0*ms0 + S1*ms1; - - if (tiisg == 0) { - ss[j*T + 0] = S; - ss[j*T + 1] = M; - } + threadgroup_barrier(mem_flags::mem_threadgroup); - for (int64_t i = 0; i < L4; ++i) { - ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; - } + // each simdgroup stores its output to shared memory, reusing sq4 + if (sgitg == sg) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); } } - } -#else - // parallel reduce - // NOTE: this is significantly slower than the serial version above, likely due to the small number of warps - { - half S = { 0.0h }; - half M = { -INFINITY }; - for (int64_t sg = nsg/2; sg > 0; sg /= 2) { - if (sgitg >= sg) { - continue; - } + threadgroup_barrier(mem_flags::mem_threadgroup); + // the first simdgroup accumulates the results from the other simdgroups + if (sgitg == 0) { for (int64_t j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*(D + 1*C) + 0]; + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*SH + 0]; - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*(D + 1*C) + 1]; + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*SH + 1]; M = max(M0, M1); @@ -2290,28 +2259,47 @@ kernel void kernel_flash_attn_ext_f16( if (tiisg == 0) { ss[j*T + 0] = S; ss[j*T + 1] = M; - } - for (int64_t i = 0; i < L4; ++i) { - ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; } } - threadgroup_barrier(mem_flags::mem_threadgroup); + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + { + simdgroup_half8x8 t; + simdgroup_half8x8 ms0; + simdgroup_half8x8 ms1; + + simdgroup_load(ms0, ss + C, T, 0, false); + simdgroup_load(ms1, ss + C + sg*SH, T, 0, false); + + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load (t, sq + i*8, T, 0, false); + simdgroup_multiply(t, ms1, t); + + simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + } + } } } -#endif - simdgroup_barrier(mem_flags::mem_threadgroup); + // store result to shared memory (reuse sq4) + if (sgitg == 0) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } device float4 * dst4 = (device float4 *) dst; + // final rescale with 1/S and store to global memory if (sgitg == 0) { for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; for (int64_t i = 0; i < L4; ++i) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ps4[j*T4 + N4*i + tiisg]/S; + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + NW*i + tiisg] = (float4) sq4[j*T4 + NW*i + tiisg]/S; } } } From ecc466a460abc7ad73df3b22a3e0957170bcf7b9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 15:42:57 +0200 Subject: [PATCH 062/180] metal : add tests, fix scaling, support C > 32 --- ggml-metal.m | 6 ++-- ggml-metal.metal | 62 ++++++++++++++++++++------------------ tests/test-backend-ops.cpp | 14 ++++++--- 3 files changed, 46 insertions(+), 36 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a7e126bff5318..484ef89398e7a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2213,12 +2213,12 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! (multiple of 8) + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! (multiple of 32) // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/32, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); diff --git a/ggml-metal.metal b/ggml-metal.metal index b564f014de2b6..7b604eb61a177 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2041,7 +2041,6 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; const int64_t D8 = D/8; const int64_t NW = N_SIMDWIDTH; - const int64_t L4 = (D4 + NW - 1)/NW; const int64_t SH = (C + Q); // shared memory per simdgroup in (half) const int64_t T = D + nsg*SH; // shared memory size per query in (half) @@ -2054,14 +2053,15 @@ kernel void kernel_flash_attn_ext_f16( // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) simdgroup_half8x8 lo[D8]; - for (int64_t i = 0; i < L4; ++i) { - // load heads from Q to shared memory - for (int64_t j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + // load heads from Q to shared memory + for (int64_t j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (int64_t i = tiisg; i < D4; i += NW) { if (iq1 + j < ne01) { - sq4[j*T4 + NW*i + tiisg] = (half4) q4[NW*i + tiisg]; + sq4[j*T4 + i] = (half4) q4[i]; } else { - sq4[j*T4 + NW*i + tiisg] = 0.0h; + sq4[j*T4 + i] = 0.0h; } } } @@ -2072,12 +2072,9 @@ kernel void kernel_flash_attn_ext_f16( } // zero out shared memory SH - if (tiisg < C) { - for (int64_t j = 0; j < Q; ++j) { - ss[j*T + tiisg] = 0.0h; - if (tiisg < Q) { - ss[j*T + C + tiisg] = 0.0h; - } + for (int64_t j = 0; j < Q; ++j) { + for (int64_t i = tiisg; i < SH; i += NW) { + ss[j*T + i] = 0.0h; } } @@ -2157,27 +2154,34 @@ kernel void kernel_flash_attn_ext_f16( // online softmax for (int64_t j = 0; j < Q; ++j) { - const int64_t p = tiisg; - - const half s = ss[j*T + p]; + const half m = M[j]; - smax = simd_max(max(smax, s)); - M[j] = simd_max(max(M[j], s)); + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; - const half m = M[j]; + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + } - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + const half ms = exp(m - M[j]); - S[j] = S[j]*ms + simd_sum(vs); + S[j] = S[j]*ms; // create an 8x8 diagonal matrix for rescaling the output - if (p == j) { + if (tiisg == j) { ss[j*T + C + j] = ms; } - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; + + const half vs = exp(s - M[j]); + + S[j] = S[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } } // skip -INF blocks @@ -2231,7 +2235,7 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); - // each simdgroup stores its output to shared memory, reusing sq4 + // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { for (int64_t i = 0; i < D8; ++i) { simdgroup_store(lo[i], sq + i*8, T, 0, false); @@ -2284,7 +2288,7 @@ kernel void kernel_flash_attn_ext_f16( } } - // store result to shared memory (reuse sq4) + // store result to shared memory (reuse sq) if (sgitg == 0) { for (int64_t i = 0; i < D8; ++i) { simdgroup_store(lo[i], sq + i*8, T, 0, false); @@ -2298,8 +2302,8 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int64_t i = 0; i < L4; ++i) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + NW*i + tiisg] = (float4) sq4[j*T4 + NW*i + tiisg]/S; + for (int64_t i = tiisg; i < D4; i += NW) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; } } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 4c98bef7cf3a6..4093a52f2eef1 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1395,7 +1395,7 @@ struct test_flash_attn_ext : public test_case { } double max_nmse_err() override { - return 5e-5; + return 5e-4; } test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) @@ -1677,9 +1677,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 8)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 7)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 1)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 8)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 7)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 1)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 8)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 7)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 1)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 8)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 7)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 3a428a10973a751af72b55b9ef396de9c305c6ac Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 17:47:22 +0200 Subject: [PATCH 063/180] metal : improve precision --- ggml-metal.metal | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 7b604eb61a177..b6b5fd997b93a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2120,7 +2120,7 @@ kernel void kernel_flash_attn_ext_f16( device const float * mp = (device const float *) (mask + (ir%ne31)*nb31); // prepare diagonal scale matrix - simdgroup_half8x8 mscale(scale); + simdgroup_float8x8 mscale(scale); // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -2163,7 +2163,7 @@ kernel void kernel_flash_attn_ext_f16( M[j] = simd_max(max(M[j], s)); } - const half ms = exp(m - M[j]); + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); S[j] = S[j]*ms; @@ -2175,7 +2175,7 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; - const half vs = exp(s - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); S[j] = S[j] + simd_sum(vs); @@ -2255,8 +2255,8 @@ kernel void kernel_flash_attn_ext_f16( M = max(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M); + const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M); S = S0*ms0 + S1*ms1; From 8612864108760897261d0d10101f68355899b03f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 18:10:16 +0200 Subject: [PATCH 064/180] ggml : fix f16 mad --- ggml.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 6bba840d93d0c..fc0886aecf5a1 100644 --- a/ggml.c +++ b/ggml.c @@ -1344,12 +1344,12 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const // leftovers for (int i = np; i < n; ++i) { - y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); } #else // scalar for (int i = 0; i < n; ++i) { - y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); } #endif } From 134c81c78dfdeaca988ea2505cc6f0c0aec2d243 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 22:23:40 +0200 Subject: [PATCH 065/180] metal : minor --- ggml-metal.metal | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index b6b5fd997b93a..ad6a4a318f4c3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2127,15 +2127,14 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) { // Q*K^T { - simdgroup_half8x8 mk; - for (int cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); + simdgroup_half8x8 mk; + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } @@ -2192,7 +2191,6 @@ kernel void kernel_flash_attn_ext_f16( // O = diag(ms)*O { simdgroup_half8x8 mm; - simdgroup_load(mm, ss + C, T, 0, false); for (int64_t i = 0; i < D8; ++i) { @@ -2202,8 +2200,6 @@ kernel void kernel_flash_attn_ext_f16( // O = O + (Q*K^T)*V { - simdgroup_half8x8 mv; - for (int cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mp; simdgroup_load(mp, ss + 8*cc, T, 0, false); @@ -2211,6 +2207,7 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t i = 0; i < D8; ++i) { device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + simdgroup_half8x8 mv; simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]); From 1db22d7032fd55a612e400164cb70ad238bbc055 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 23:08:31 +0200 Subject: [PATCH 066/180] metal : support Q > 8 --- examples/batched-bench/batched-bench.cpp | 2 +- ggml-metal.m | 7 ++- ggml-metal.metal | 80 +++++++++++++++--------- 3 files changed, 55 insertions(+), 34 deletions(-) diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 7924db267401c..4992b57f6f9db 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -104,7 +104,7 @@ int main(int argc, char ** argv) { ctx_params.seed = 1234; ctx_params.n_ctx = n_kv_max; - ctx_params.n_batch = 512; + ctx_params.n_batch = 2048; ctx_params.mul_mat_q = mmq; ctx_params.n_threads = params.n_threads; diff --git a/ggml-metal.m b/ggml-metal.m index ef799ef57b643..a0dd1d0df5bcb 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2206,8 +2206,11 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! (multiple of 8) - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! (multiple of 32) + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) diff --git a/ggml-metal.metal b/ggml-metal.metal index ad6a4a318f4c3..08c000cc4c027 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2040,6 +2040,7 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; const int64_t D8 = D/8; + const int64_t Q8 = Q/8; const int64_t NW = N_SIMDWIDTH; const int64_t SH = (C + Q); // shared memory per simdgroup in (half) @@ -2051,7 +2052,7 @@ kernel void kernel_flash_attn_ext_f16( threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - simdgroup_half8x8 lo[D8]; + simdgroup_half8x8 lo[Q8][D8]; // load heads from Q to shared memory for (int64_t j = sgitg; j < Q; j += nsg) { @@ -2067,8 +2068,10 @@ kernel void kernel_flash_attn_ext_f16( } // zero out lo - for (int64_t i = 0; i < D8; ++i) { - lo[i] = make_filled_simdgroup_matrix(0.0h); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + lo[j][i] = make_filled_simdgroup_matrix(0.0h); + } } // zero out shared memory SH @@ -2108,10 +2111,12 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv3 = iq3 / rv3; // load the queries from shared memory into local memory - simdgroup_half8x8 mq[D8]; + simdgroup_half8x8 mq[Q8][D8]; - for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mq[i], sq + i*8, T); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); + } } const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; @@ -2128,7 +2133,10 @@ kernel void kernel_flash_attn_ext_f16( // Q*K^T { for (int cc = 0; cc < C/8; ++cc) { - simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); + simdgroup_half8x8 mqk[Q8]; + for (int64_t j = 0; j < Q8; ++j) { + mqk[j] = make_filled_simdgroup_matrix(0.h); + } device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); @@ -2136,15 +2144,19 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_half8x8 mk; simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose - simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); + } } // mqk = mqk*scale + mask - simdgroup_float8x8 mm; - simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(float), 0, false); - simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_float8x8 mm; + simdgroup_load(mm, mp + 8*j*(nb31/sizeof(float)) + ic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); - simdgroup_store(mqk, ss + 8*cc, T, 0, false); + simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); + } } } @@ -2166,7 +2178,7 @@ kernel void kernel_flash_attn_ext_f16( S[j] = S[j]*ms; - // create an 8x8 diagonal matrix for rescaling the output + // create a QxQ diagonal matrix for rescaling the output if (tiisg == j) { ss[j*T + C + j] = ms; } @@ -2189,28 +2201,30 @@ kernel void kernel_flash_attn_ext_f16( } // O = diag(ms)*O - { + for (int64_t j = 0; j < Q8; ++j) { simdgroup_half8x8 mm; - simdgroup_load(mm, ss + C, T, 0, false); + simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false); for (int64_t i = 0; i < D8; ++i) { - simdgroup_multiply(lo[i], mm, lo[i]); + simdgroup_multiply(lo[j][i], mm, lo[j][i]); } } // O = O + (Q*K^T)*V { for (int cc = 0; cc < C/8; ++cc) { - simdgroup_half8x8 mp; - simdgroup_load(mp, ss + 8*cc, T, 0, false); + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); for (int64_t i = 0; i < D8; ++i) { - device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + simdgroup_half8x8 mk; + simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); - simdgroup_half8x8 mv; - simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_half8x8 mv; + simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false); - simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]); + simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]); + } } } } @@ -2234,8 +2248,10 @@ kernel void kernel_flash_attn_ext_f16( // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { - for (int64_t i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); + } } } @@ -2267,19 +2283,19 @@ kernel void kernel_flash_attn_ext_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - { + for (int64_t j = 0; j < Q8; ++j) { simdgroup_half8x8 t; simdgroup_half8x8 ms0; simdgroup_half8x8 ms1; - simdgroup_load(ms0, ss + C, T, 0, false); - simdgroup_load(ms1, ss + C + sg*SH, T, 0, false); + simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false); + simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false); for (int64_t i = 0; i < D8; ++i) { - simdgroup_load (t, sq + i*8, T, 0, false); + simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false); simdgroup_multiply(t, ms1, t); - simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t); } } } @@ -2287,8 +2303,10 @@ kernel void kernel_flash_attn_ext_f16( // store result to shared memory (reuse sq) if (sgitg == 0) { - for (int64_t i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); + } } } From 4794821a31d5778b3398b8375d29fa63a539c8c4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 16:44:55 +0200 Subject: [PATCH 067/180] tests : add ATTN tests --- tests/test-backend-ops.cpp | 70 +++++++++++++++++++++++++++++++++----- 1 file changed, 61 insertions(+), 9 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c18ff07ea4d21..0ce498e9e7dd4 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1418,6 +1418,48 @@ struct test_flash_attn_ext : public test_case { } }; +// Attention +struct test_attn : public test_case { + const int64_t hs; // head size + const int64_t nh; // num heads + const int64_t kv; // kv size + const int64_t nb; // batch size + + std::string op_desc(ggml_tensor * t) override { + return "ATTN"; + + GGML_UNUSED(t); + } + + std::string vars() override { + return VARS_TO_STR4(hs, nh, kv, nb); + } + + double max_nmse_err() override { + return 5e-4; + } + + test_attn(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : hs(hs), nh(nh), kv(kv), nb(nb) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); + ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); + + struct ggml_tensor * cur; + + cur = ggml_mul_mat (ctx, k, q); + cur = ggml_soft_max_ext(ctx, cur, mask, 1.0f/sqrtf(hs)); + cur = ggml_mul_mat (ctx, v, cur); + cur = ggml_permute (ctx, cur, 0, 2, 1, 3); + cur = ggml_cont_2d (ctx, cur, hs*nh, nb); + + return cur; + } +}; + // Mixtral MOE struct test_moe : public test_case { const int n_experts; @@ -1684,15 +1726,25 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 8)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 7)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 1)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 8)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 7)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 1)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 8)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 7)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 1)); + test_cases.emplace_back(new test_attn(64, 32, 512, 8)); + test_cases.emplace_back(new test_attn(64, 32, 512, 7)); + test_cases.emplace_back(new test_attn(64, 32, 512, 1)); + test_cases.emplace_back(new test_attn(80, 32, 512, 8)); + test_cases.emplace_back(new test_attn(80, 32, 512, 7)); + test_cases.emplace_back(new test_attn(80, 32, 512, 1)); + test_cases.emplace_back(new test_attn(128, 32, 512, 8)); + test_cases.emplace_back(new test_attn(128, 32, 512, 7)); + test_cases.emplace_back(new test_attn(128, 32, 512, 1)); + + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 8)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 7)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 1)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 8)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 7)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 1)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 7)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From abeaf0d90ee82096a0aba20785f1e37bd1f3aa41 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 18:12:24 +0200 Subject: [PATCH 068/180] metal : disable buffer allocation logs --- ggml-metal.m | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a0dd1d0df5bcb..a637f04875dbe 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2421,10 +2421,13 @@ GGML_CALL static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buff UNUSED(buft); } -static void ggml_backend_metal_log_allocated_size(id device) { +static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { +#ifndef GGML_METAL_NDEBUG #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) if (@available(macOS 10.12, iOS 16.0, *)) { - GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)", + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)", + __func__, + size_aligned / 1024.0 / 1024.0, device.currentAllocatedSize / 1024.0 / 1024.0, device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); @@ -2434,10 +2437,15 @@ static void ggml_backend_metal_log_allocated_size(id device) { GGML_METAL_LOG_INFO("\n"); } } else { - GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0); + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", + __func__, + size_aligned / 1024.0 / 1024.0, + device.currentAllocatedSize / 1024.0 / 1024.0); } +#endif #endif UNUSED(device); + UNUSED(size_aligned); } GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -2471,8 +2479,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff return NULL; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); - ggml_backend_metal_log_allocated_size(device); + ggml_backend_metal_log_allocated_size(device, size_aligned); return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size); } @@ -2549,7 +2556,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, return false; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); + ggml_backend_metal_log_allocated_size(device, size_aligned); ++ctx->n_buffers; } else { @@ -2572,7 +2579,8 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, return false; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i); + ggml_backend_metal_log_allocated_size(device, size_step_aligned); + if (i + size_step < size) { GGML_METAL_LOG_INFO("\n"); } @@ -2581,8 +2589,6 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, } } - ggml_backend_metal_log_allocated_size(device); - return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size); } From c6c1132e5e6658b3c209433ed5ef75067ef31a2f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 18:22:28 +0200 Subject: [PATCH 069/180] tests : more --- ggml-metal.m | 9 +++++++++ ggml-metal.metal | 3 +++ ggml.c | 5 ----- tests/test-backend-ops.cpp | 29 ++++++++++------------------- 4 files changed, 22 insertions(+), 24 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a637f04875dbe..4b5fd0bb8fc58 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -137,7 +137,10 @@ GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -505,7 +508,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -2166,7 +2172,10 @@ static bool ggml_metal_graph_compute( switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; default: { GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); diff --git a/ggml-metal.metal b/ggml-metal.metal index 08c000cc4c027..be059d78f505a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2326,7 +2326,10 @@ kernel void kernel_flash_attn_ext_f16( template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; kernel void kernel_cpy_f16_f16( device const half * src0, diff --git a/ggml.c b/ggml.c index e8a5fcfa485c1..57271a1ad43e3 100644 --- a/ggml.c +++ b/ggml.c @@ -13554,11 +13554,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t D = neq0; const int64_t N = neq1; - const int64_t P = nek1 - N; GGML_ASSERT(ne0 == D); GGML_ASSERT(ne2 == N); - GGML_ASSERT(P >= 0); GGML_ASSERT(nbq0 == sizeof(float)); GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); @@ -13569,7 +13567,6 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nev0 == D); GGML_ASSERT(neq1 == N); - GGML_ASSERT(nek1 == N + P); GGML_ASSERT(nev0 == D); // dst cannot be transposed or permuted @@ -13608,8 +13605,6 @@ static void ggml_compute_forward_flash_attn_ext_f16( float scale = 1.0f; memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0ce498e9e7dd4..f57e8ab1a853e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1726,25 +1726,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_attn(64, 32, 512, 8)); - test_cases.emplace_back(new test_attn(64, 32, 512, 7)); - test_cases.emplace_back(new test_attn(64, 32, 512, 1)); - test_cases.emplace_back(new test_attn(80, 32, 512, 8)); - test_cases.emplace_back(new test_attn(80, 32, 512, 7)); - test_cases.emplace_back(new test_attn(80, 32, 512, 1)); - test_cases.emplace_back(new test_attn(128, 32, 512, 8)); - test_cases.emplace_back(new test_attn(128, 32, 512, 7)); - test_cases.emplace_back(new test_attn(128, 32, 512, 1)); - - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 8)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 7)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 1)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 8)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 7)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 1)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 7)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 1)); + for (int hs : { 64, 80, 96, 112, 128, 256, }) { + for (int nh : { 32, }) { + for (int kv : { 512, 1024, 2048, 4096, }) { + for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { + test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); + } + } + } + } #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 5fcb9c1c5af108056c8ad51fc1995de9d7707d2f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 19:46:22 +0200 Subject: [PATCH 070/180] metal : faster inner loop for C == 32 --- ggml-metal.metal | 59 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 17 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index be059d78f505a..db4c7cfde0037 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2048,8 +2048,8 @@ kernel void kernel_flash_attn_ext_f16( const int64_t T4 = T/4; // shared memory size per query in (half4) threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // scratch buffer for attention - threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) simdgroup_half8x8 lo[Q8][D8]; @@ -2164,34 +2164,59 @@ kernel void kernel_flash_attn_ext_f16( half smax = -INFINITY; // online softmax - for (int64_t j = 0; j < Q; ++j) { - const half m = M[j]; + if (C == 32) { + for (int64_t j = 0; j < Q; ++j) { + const int64_t p = tiisg; - for (int64_t p = tiisg; p < C; p += NW) { + const half m = M[j]; const half s = ss[j*T + p]; smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); - } - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - S[j] = S[j]*ms; + S[j] = S[j]*ms + simd_sum(vs); + + // create a QxQ diagonal matrix for rescaling the output + if (p == j) { + ss[j*T + C + j] = ms; + } - // create a QxQ diagonal matrix for rescaling the output - if (tiisg == j) { - ss[j*T + C + j] = ms; + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; } + } else { + for (int64_t j = 0; j < Q; ++j) { + const half m = M[j]; - for (int64_t p = tiisg; p < C; p += NW) { - const half s = ss[j*T + p]; + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + } - S[j] = S[j] + simd_sum(vs); + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + S[j] = S[j]*ms; + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg == j) { + ss[j*T + C + j] = ms; + } + + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; + + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + + S[j] = S[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } } } From ffdfd13abe602a721716d7cb15809bae5293b4ce Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Wed, 31 Jan 2024 00:54:24 +0900 Subject: [PATCH 071/180] updated sampling api --- common/sampling.cpp | 19 +++++++++++++++++-- common/sampling.h | 7 +++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index b2a249ad4d262..69bb6be74b2c4 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -36,10 +36,10 @@ void llama_sampling_free(struct llama_sampling_context * ctx) { delete ctx; } -void llama_sampling_reset(llama_sampling_context * ctx) { +void llama_sampling_reset_grammar(struct llama_sampling_context * ctx) { if (ctx->grammar != NULL) { llama_grammar_free(ctx->grammar); - ctx->grammar = NULL; + ctx->grammar = nullptr; } if (!ctx->parsed_grammar.rules.empty()) { @@ -49,6 +49,10 @@ void llama_sampling_reset(llama_sampling_context * ctx) { grammar_rules.data(), grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root")); } +} + +void llama_sampling_reset(llama_sampling_context * ctx) { + llama_sampling_reset_grammar(ctx); std::fill(ctx->prev.begin(), ctx->prev.end(), 0); ctx->cur.clear(); @@ -331,3 +335,14 @@ void llama_sampling_accept( llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id); } } + + +void llama_sampling_rollback( + struct llama_sampling_context * ctx_sampling, + int rollback_num) { + if(rollback_num > ctx_sampling->prev.size()) { + rollback_num = ctx_sampling->prev.size(); + } + + ctx_sampling->prev.erase(ctx_sampling->prev.end() - rollback_num, ctx_sampling->prev.end()); +} \ No newline at end of file diff --git a/common/sampling.h b/common/sampling.h index 78ced2c5fead2..a5277dc0107d6 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -68,6 +68,9 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_ void llama_sampling_free(struct llama_sampling_context * ctx); +// Reset the sampler grammar without resetting the context +void llama_sampling_reset_grammar(struct llama_sampling_context * ctx); + // Reset the sampler context // - clear prev tokens // - reset grammar @@ -116,3 +119,7 @@ void llama_sampling_accept( struct llama_context * ctx_main, llama_token id, bool apply_grammar); + +void llama_sampling_rollback( + struct llama_sampling_context * ctx_sampling, + int rollback_num); From d073e4f93337560e552f0d3de4b6b07bf13ef3f5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Jan 2024 21:45:32 +0200 Subject: [PATCH 072/180] metal : fix array initialization --- ggml-metal.metal | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index db4c7cfde0037..41f6169de8abd 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2084,8 +2084,8 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); { - half S[Q] = { 0.0h }; - half M[Q] = { -INFINITY }; + half S[Q] = { [0 ... Q-1] = 0.0h }; + half M[Q] = { [0 ... Q-1] = -INFINITY }; // assume K and V are same shape const int64_t ne22 = ne12; From 78df5527e4e9eafb181200384fbed80c8116042e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Jan 2024 21:46:49 +0200 Subject: [PATCH 073/180] tests : ifdef --- tests/test-backend-ops.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index f57e8ab1a853e..07182c6d8aa63 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1726,6 +1726,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); +#if 0 for (int hs : { 64, 80, 96, 112, 128, 256, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { @@ -1736,6 +1737,18 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } } +#else + for (int hs : { 128, }) { + for (int nh : { 32, }) { + for (int kv : { 512, 1024, }) { + for (int nb : { 1, 2, 4, 8, 512 }) { + test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); + } + } + } + } +#endif #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 8ad92dc1ec9aa6549c68900daa7ab93b57fa3ae5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 31 Jan 2024 19:17:16 +0200 Subject: [PATCH 074/180] ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext --- ggml-cuda.cu | 20 +++++++++---------- ggml-metal.m | 6 ++++++ ggml-metal.metal | 40 ++++++++++++++++++-------------------- ggml.c | 13 +++++++++---- ggml.h | 12 +++++++----- llama.cpp | 40 ++++++++++++++++++++++---------------- tests/test-backend-ops.cpp | 10 +++++----- 7 files changed, 79 insertions(+), 62 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e565957421795..c57a031e4060c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5917,7 +5917,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int } template -static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { +static __global__ void soft_max_f16(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template; const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2; @@ -5952,12 +5952,12 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds if (need_check && col_data + 0 >= ncols_data) { val.x = -INFINITY; } else { - val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f); + val.x = x[ix + 0]*scale + (y ? __half2float(y[iy + 0]) : 0.0f); } if (need_check && col_data + WARP_SIZE >= ncols_data) { val.y = -INFINITY; } else { - val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f); + val.y = x[ix + WARP_SIZE]*scale + (y ? __half2float(y[iy + WARP_SIZE]) : 0.0f); } if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) { vals[col_smem] = val; @@ -6047,7 +6047,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds } template -static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { +static __global__ void soft_max_f32(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -6077,7 +6077,7 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (y ? y[iy] : 0.0f); + const float val = x[ix]*scale + (y ? __half2float(y[iy]) : 0.0f); vals[col] = val; max_val = max(max_val, val); } @@ -7585,7 +7585,7 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols diag_mask_inf_f32<<>>(x, dst, ncols_x, rows_per_channel, n_past); } -static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { +static void soft_max_f16_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -7628,7 +7628,7 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con } } -static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -9060,7 +9060,7 @@ static void ggml_cuda_op_soft_max( GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -9080,9 +9080,9 @@ static void ggml_cuda_op_soft_max( #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX if (use_f16_soft_max) { - soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); + soft_max_f16_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); } else { - soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); + soft_max_f32_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); } (void) dst; diff --git a/ggml-metal.m b/ggml-metal.m index 15e5568f960f1..e00069624551f 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1187,6 +1187,8 @@ static bool ggml_metal_graph_compute( } break; case GGML_OP_SOFT_MAX: { + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); + int nth = 32; // SIMD width id pipeline = nil; @@ -2213,6 +2215,10 @@ static bool ggml_metal_graph_compute( id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); + GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && + "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); const int64_t ne31 = src3 ? src3->ne[1] : 0; const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); diff --git a/ggml-metal.metal b/ggml-metal.metal index b2e40715d4f2d..04c1aaf9cdfb9 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -349,9 +349,9 @@ kernel void kernel_sum_rows( } kernel void kernel_soft_max( - device const float * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -366,9 +366,9 @@ kernel void kernel_soft_max( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const half * pmask = src1 != src0 ? (device const half *) src1 + i01*ne00 : nullptr; + device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); // parallel max float lmax = -INFINITY; @@ -435,14 +435,14 @@ kernel void kernel_soft_max( } kernel void kernel_soft_max_4( - device const float * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, constant float & scale, - threadgroup float * buf [[threadgroup(0)]], + threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], @@ -452,15 +452,15 @@ kernel void kernel_soft_max_4( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const half4 * pmask = src1 != src0 ? (device const half4 *) src1 + i01*ne00/4 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; // parallel max float4 lmax4 = -INFINITY; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)); + lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f)); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -486,7 +486,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val); + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f)) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } @@ -2144,13 +2144,11 @@ kernel void kernel_flash_attn_ext_f16( } } - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - // pointer to the mask - device const float * mp = (device const float *) (mask + (ir%ne31)*nb31); + device const half * mp = (device const half *) (mask + iq1*nb31); // prepare diagonal scale matrix - simdgroup_float8x8 mscale(scale); + simdgroup_half8x8 mscale(scale); // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -2176,8 +2174,8 @@ kernel void kernel_flash_attn_ext_f16( // mqk = mqk*scale + mask for (int64_t j = 0; j < Q8; ++j) { - simdgroup_float8x8 mm; - simdgroup_load(mm, mp + 8*j*(nb31/sizeof(float)) + ic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_half8x8 mm; + simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); diff --git a/ggml.c b/ggml.c index 466a8cdec3c9d..59a4c05a12ffe 100644 --- a/ggml.c +++ b/ggml.c @@ -5085,6 +5085,7 @@ static struct ggml_tensor * ggml_soft_max_impl( bool inplace) { GGML_ASSERT(ggml_is_contiguous(a)); if (mask) { + GGML_ASSERT(mask->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); GGML_ASSERT(mask->ne[3] == 1); @@ -5854,6 +5855,8 @@ struct ggml_tensor * ggml_flash_attn_ext( GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); GGML_ASSERT(mask->ne[3] == 1); + GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && + "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); } @@ -11552,12 +11555,14 @@ static void ggml_compute_forward_soft_max_f32( float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); // broadcast the mask across rows - float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; + ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; ggml_vec_cpy_f32 (nc, wp, sp); ggml_vec_scale_f32(nc, wp, scale); if (mp) { - ggml_vec_acc_f32(nc, wp, mp); + for (int i = 0; i < nc; ++i) { + wp[i] += GGML_FP16_TO_FP32(mp[i]); + } } #ifndef NDEBUG @@ -13760,7 +13765,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( memset(V16, 0, D*sizeof(ggml_fp16_t)); - const float * mp = mask ? (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]) : NULL; + const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; // k indices const int ik3 = iq3 / rk3; @@ -13774,7 +13779,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( // loop over n_kv and n_head_kv // ref: https://arxiv.org/pdf/2112.05682.pdf for (int64_t ic = 0; ic < nek1; ++ic) { - const float mv = mp ? mp[ic] : 0.0f; + const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f; if (mv == -INFINITY) { continue; } diff --git a/ggml.h b/ggml.h index a83ff8035f9ea..74ce1abd4d500 100644 --- a/ggml.h +++ b/ggml.h @@ -1646,11 +1646,13 @@ extern "C" { struct ggml_tensor * v, bool masked); - // q: [n_embd, n_batch, n_head, 1] - // k: [n_embd, n_kv, n_head_kv, 1] - // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! - // mask: [n_kv, n_batch, 1, 1] - // res: [n_embd, n_head, n_batch, 1] !! permuted !! +#define GGML_KQ_MASK_PAD 32 + + // q: [n_embd, n_batch, n_head, 1] + // k: [n_embd, n_kv, n_head_kv, 1] + // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! + // res: [n_embd, n_head, n_batch, 1] !! permuted !! GGML_API struct ggml_tensor * ggml_flash_attn_ext( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index 1f8ecc19b4e0c..fe25839669efc 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4721,7 +4721,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -4905,7 +4905,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5026,7 +5026,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5148,7 +5148,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); @@ -5245,7 +5245,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); if (do_rope_shift) { @@ -5448,7 +5448,7 @@ struct llm_build_context { cb(inpL, "inp_embd", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -5538,7 +5538,7 @@ struct llm_build_context { cb(inpL, "inp_embd", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); inpL = llm_build_norm(ctx0, inpL, hparams, @@ -5631,7 +5631,7 @@ struct llm_build_context { cb(inpL, "inp_embd", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -5731,7 +5731,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5854,7 +5854,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5968,7 +5968,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6089,7 +6089,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6211,7 +6211,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6318,7 +6318,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); @@ -6416,7 +6416,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6524,7 +6524,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -10250,7 +10250,10 @@ struct llama_context * llama_new_context_with_model( const auto & hparams = model->hparams; auto & cparams = ctx->cparams; - cparams.n_batch = params.n_batch; + // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask + // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext) + cparams.n_batch = std::max((uint32_t) GGML_KQ_MASK_PAD, params.n_batch); + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -10430,6 +10433,9 @@ struct llama_context * llama_new_context_with_model( ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true)); + // zero-out the input buffer to prevent NaNs in padded tensors + ggml_backend_buffer_clear(ctx->buf_input, 0); + LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(ctx->buf_input), ggml_backend_buffer_get_size(ctx->buf_input) / 1024.0 / 1024.0); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0f31c00f9672c..b1b30b91c9c6b 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1101,7 +1101,7 @@ struct test_soft_max : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_tensor * b = nullptr; - if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); } + if (mask) { b = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, ne[0], ne[1]); } ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale); return out; } @@ -1472,7 +1472,7 @@ struct test_flash_attn_ext : public test_case { ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); - ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1); ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); return out; } @@ -1506,7 +1506,7 @@ struct test_attn : public test_case { ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed - ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, nb, 1, 1); struct ggml_tensor * cur; @@ -1793,7 +1793,7 @@ struct test_llama : public test_llm { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1); ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); @@ -1915,7 +1915,7 @@ struct test_falcon : public test_llm { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1); ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); From 910b15bb4006409fe24b41da171cc562cdb1f3a4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 16:41:02 +0200 Subject: [PATCH 075/180] ggml : fix ggml_soft_max mask requirement --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 59a4c05a12ffe..ebd9c6b341080 100644 --- a/ggml.c +++ b/ggml.c @@ -5089,7 +5089,7 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); GGML_ASSERT(mask->ne[3] == 1); - GGML_ASSERT(ggml_can_repeat_rows(mask, a)); + GGML_ASSERT(mask->ne[1] >= a->ne[1]); } bool is_node = false; From 2e460137490a4e002a60a60aed052e90179bb65b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 16:47:20 +0200 Subject: [PATCH 076/180] cuda : fix soft_max to use correct mask size --- ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c57a031e4060c..15fc6154f7508 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -9064,7 +9064,7 @@ static void ggml_cuda_op_soft_max( const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1; + const int64_t nrows_y = src1 ? src0->ne[1] : 1; // note: using number of queries since mask can be padded! float scale = 1.0f; memcpy(&scale, dst->op_params, sizeof(float)); From 5a19a9f6d0899becbc71a19454a27c0225edddf7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 19:47:11 +0200 Subject: [PATCH 077/180] cuda : add flash_attn kernel (wip) --- ggml-cuda.cu | 735 ++++++++++++++++++++++++++++++++++++++++++++++++++- llama.cpp | 3 +- 2 files changed, 735 insertions(+), 3 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 15fc6154f7508..60d228a61660f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -108,6 +108,7 @@ #include #include #include +#include #if CUDART_VERSION < 11020 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED @@ -655,6 +656,19 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } +static __device__ __forceinline__ half warp_reduce_sum(half x) { +#if __CUDA_ARCH__ >= CC_VOLTA +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hadd(__shfl_xor_sync(0xffffffff, x, mask, 32), x); + } + return x; +#else + (void) x; + NO_DEVICE_CODE; +#endif +} + static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { @@ -676,6 +690,18 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX } +static __device__ __forceinline__ half warp_reduce_max(half x) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hmax(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +#else + (void) x; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +} + static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; GGML_UNUSED(a); @@ -989,6 +1015,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr if (lane_id == 0) { s_sum[warp_id] = tmp; } + __syncthreads(); tmp = s_sum[lane_id]; tmp = warp_reduce_sum(tmp); @@ -6249,6 +6276,528 @@ static __global__ void pool2d_nchw_kernel( o_ptr[cur_oh * ow + cur_ow] = res; } +#define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256 + +template +static __global__ void flash_attn_f32( + const float* __restrict__ q, + const float* __restrict__ k, + const float* __restrict__ v, + float* __restrict__ kqv, + float kq_scale, + int head_dim, int seq_len, int num_heads) { + const int head = blockIdx.x / seq_len; + const int head_size = head_dim * seq_len; + const int s = blockIdx.x % seq_len; + + extern __shared__ char flash_attn_shmem_f32[]; + float* S = (float*)flash_attn_shmem_f32; + float* warp_data = (float*)(flash_attn_shmem_f32 + seq_len * sizeof(float)); + + // QK^T + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + + const int key_offset = is * head_dim + head * head_size; + const int query_offset = s * head_dim + head * head_size; + + float tmp = 0.0f; + for(int d = 0; d < head_dim; d++) { + tmp += k[key_offset + d] * q[query_offset + d]; + } + S[is] = tmp * kq_scale; + } + __syncthreads(); + + float max_val = -INFINITY; + // get the max + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + + max_val = fmaxf(max_val , S[is]); + } + + max_val = warp_reduce_max(max_val); + + { // get max from all threads + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + warp_data[warp_id] = max_val; + } + __syncthreads(); + max_val = warp_data[lane_id]; + max_val = warp_reduce_max(max_val); + } + + // softmax(QK^T) + float sum = 0.0f; + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + float tmp = expf(S[is] - max_val); + sum += tmp; + S[is] = tmp; + } + __syncthreads(); + + sum = warp_reduce_sum(sum); + { // softmax sum partials + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + warp_data[warp_id] = sum; + } + __syncthreads(); + sum = warp_data[lane_id]; + sum = warp_reduce_sum(sum); + } + + float inv_sum = 1.0f / sum; + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + + S[is] *= inv_sum; + } + __syncthreads(); + + // softmax(QK^T)V + #pragma unroll + for (int d0 = threadIdx.x; d0 < k_head_dim; d0 += block_size) { + const int d = threadIdx.x + d0; + if(d >= head_dim) { + break; + } + const int dst_index = d + s * head_dim + head * head_size; + const int value_offset = d * seq_len + head * head_size; + + float temp = 0.0f; + #pragma unroll + for(int ic = 0; ic < k_seq_len;ic++) { + if(ic >= seq_len) { + break; + } + temp += v[value_offset + ic] * S[ic]; + } + kqv[dst_index] = temp; + } +} + +#if __CUDA_ARCH__ >= CC_VOLTA +typedef nvcuda::wmma::fragment half16x16_a; +typedef nvcuda::wmma::fragment half16x16_b; +typedef nvcuda::wmma::fragment half16x16_bT; +typedef nvcuda::wmma::fragment half16x16_acc; +#endif + +// based on metal version +template // D head size, Q queries per block, C cache items per block +static __global__ void flash_attn_ext_f16( + const char* __restrict__ q, + const char* __restrict__ k, + const char* __restrict__ v, + const char* __restrict__ mask, + float* __restrict__ dst, + float scale, + int ne00, + int ne01, + int ne02, + int ne03, + int ne10, + int ne11, + int ne12, + int ne13, + int ne31, + int nb31, + int nb01, + int nb02, + int nb03, + int nb11, + int nb12, + int nb13, + int ne0, + int ne1, + int ne2, + int ne3) { +#if __CUDA_ARCH__ >= CC_VOLTA + const int warp_id = threadIdx.y; + const int lane_id = threadIdx.x; + + const int num_warps = blockDim.y; // number of warps + const int iq3 = blockIdx.z; + const int iq2 = blockIdx.y; + const int iq1 = blockIdx.x * Q; + + const int D2 = D/2; + const int D16 = D/16; + const int Q16 = Q/16; + const int NW = WARP_SIZE; + const int SH = (C + Q); // shared memory per simdgroup in (half) + + const int T = D + num_warps*SH; // shared memory size per query in (half) + const int T2 = T/2; // shared memory size per query in (half2) + + extern __shared__ half __flash_attn_f16_shmem[]; + // pq + half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data + half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 + half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix + + half16x16_acc zr; + half16x16_acc lo[Q16][D16]; + + // load heads from Q to shared memory + for (int64_t j = warp_id; j < Q; j += num_warps) { + const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (int64_t i = lane_id; i < D2; i += NW) { + if (iq1 + j < ne01) { + sq2[j*T2 + i] = __float22half2_rn(q2[i]); + } else { + sq2[j*T2 + i] = make_half2(0.0, 0.0); + } + } + } + + nvcuda::wmma::fill_fragment(zr, 0.0); + + // zero out lo + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::fill_fragment(lo[j][i], 0.0); + } + } + + // zero out shared memory SH + for (int64_t j = 0; j < Q; ++j) { + for (int64_t i = lane_id; i < SH; i += NW) { + ss[j*T + i] = 0.0; + } + } + + __syncthreads(); + + { + half S[Q]; + half M[Q]; + + for(int i = 0; i < Q; i++) { + S[i] = __float2half(0.0f); + M[i] = __float2half(-INFINITY); + } + + // assume K and V are same shape + const int ne22 = ne12; + const int ne23 = ne13; + + const int nb21 = nb11; + const int nb22 = nb12; + const int nb23 = nb13; + + // broadcast + const int rk2 = ne02/ne12; + const int rk3 = ne03/ne13; + + const int rv2 = ne02/ne22; + const int rv3 = ne03/ne23; + + // k indices + const int ik2 = iq2 / rk2; + const int ik3 = iq3 / rk3; + + // v indices + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + half16x16_a mq[Q16][D16]; + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); + } + } + + // pointer to the mask + const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr; + + // prepare diagonal scale matrix + half16x16_b mscale; + for (int i = 0; i < 16; ++i) { + ss[i*T + i] = __float2half(scale); + } + nvcuda::wmma::load_matrix_sync(mscale, ss, T); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) { + // Q*K^T + { + for (int cc = 0; cc < C/16; ++cc) { + half16x16_acc mqk[Q16]; + for (int64_t j = 0; j < Q16; ++j) { + nvcuda::wmma::fill_fragment(mqk[j], 0); + } + + const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); + + for (int64_t i = 0; i < D16; ++i) { + half16x16_bT mk; // transposed key + nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); + + for (int64_t j = 0; j < Q16; ++j) { + nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); + } + } + + // mqk = mqk*scale + mask + for (int64_t j = 0; j < Q16; ++j) { + half16x16_a mqka; + half16x16_acc mm; + if(mp) { + nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); + } + + // convert accumulator to matrix_a + nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T); + + nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr); + nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); + } + } + } + + // used to detect blocks full of -INF + half smax = __float2half(-INFINITY); + + // online softmax + if (C == 32) { + for (int64_t j = 0; j < Q; ++j) { + const int64_t p = lane_id; + + const half m = M[j]; + const half s = ss[j*T + p]; + + smax = warp_reduce_max(__hmax(smax, s)); + M[j] = warp_reduce_max(__hmax(M[j], s)); + + const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); + const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]); + + S[j] = S[j]*ms + warp_reduce_sum(vs); + + // create a QxQ diagonal matrix for rescaling the output + if (p == j) { + ss[j*T + C + j] = ms; + } + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } + } else { + for (int64_t j = 0; j < Q; ++j) { + const half m = M[j]; + + for (int64_t p = lane_id; p < C; p += NW) { + const half s = ss[j*T + p]; + + smax = __hmax(smax, s); + M[j] = __hmax(M[j], s); + } + + smax = warp_reduce_max(smax); + M[j] = warp_reduce_max(M[j]); + + const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); + + // create a QxQ diagonal matrix for rescaling the output + if (lane_id == j) { + ss[j*T + C + j] = ms; + } + + // local sum + half ls = 0.0f; + + for (int64_t p = lane_id; p < C; p += NW) { + const half s = ss[j*T + p]; + + const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]); + + ls += vs; + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } + + S[j] = S[j]*ms + warp_reduce_sum(ls); + } + } + + // skip -INF blocks + if (__hisinf(smax)) { + continue; + } + + // O = diag(ms)*O + for (int64_t j = 0; j < Q16; ++j) { + half16x16_a mm; + half16x16_b lob; + + nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); + + for (int64_t i = 0; i < D16; ++i) { + // convert accumulator to matrix_b + nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); + + nvcuda::wmma::fill_fragment(lo[j][i], 0.0); + nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]); + } + + // restore zeros + nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); + } + + // O = O + (Q*K^T)*V + { + for (int cc = 0; cc < C/16; ++cc) { + const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); + + half16x16_b mk[D16]; + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half)); + } + + half16x16_a mv[Q16]; + for (int64_t j = 0; j < Q16; ++j) { + nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T); + } + + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]); + } + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + for (int64_t j = 0; j < Q; ++j) { + if (lane_id == 0) { + ss[j*T + 0] = S[j]; + ss[j*T + 1] = M[j]; + } + } + } + + // reduce the warps sequentially + for (int64_t sg = 1; sg < num_warps; ++sg) { + half S = __float2half(0.0f); + half M = __float2half(-INFINITY); + + __syncthreads(); + + // each simdgroup stores its output to shared memory, reusing sq + if (warp_id == sg) { + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + } + } + } + + __syncthreads(); + + // the first simdgroup accumulates the results from the other simdgroups + if (warp_id == 0) { + for (int64_t j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*SH + 0]; + + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*SH + 1]; + + M = __hmax(M0, M1); + + const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M); + const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (lane_id == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; + + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; + } + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (int64_t j = 0; j < Q16; ++j) { + half16x16_a ms0; + half16x16_a ms1; + half16x16_b t; + half16x16_acc t2; + + nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); + nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); + + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::fill_fragment(t2, 0.0); + nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); + nvcuda::wmma::mma_sync(t2, ms1, t, t2); + + // convert accumulator to matrix_b + nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T); + + nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2); + } + } + } + } + + // store result to shared memory (reuse sq) + if (warp_id == 0) { + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + } + } + } + + // final rescale with 1/S and store to global memory + if (warp_id == 0) { + for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + const half S = ss[j*T + 0]; + + for (int64_t i = lane_id; i < D; i += NW) { + dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); + } + } + } +#else + NO_DEVICE_CODE; +#endif +} + template static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { @@ -7682,6 +8231,13 @@ static void im2col_cuda(const float* x, T* dst, im2col_kernel<<>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); } +static void flash_attn_f32_cuda(const float* q, const float* k,const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) { + int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float); + int num_blocks = num_heads * seq_len; + flash_attn_f32<<>>( + q, k, v, dst, kq_scale, d_head, seq_len, num_heads); +} + // buffer pool for cuda #define MAX_CUDA_BUFFERS 256 @@ -8659,7 +9215,7 @@ static void ggml_cuda_op_dequantize_mul_mat_vec( src1_dfloat = src1_dfloat_a.alloc(ne00); ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00, ne00, 1, sizeof(float), 0, 0, - ne00, 1, sizeof(half), 0, 0, stream); + ne00, 1, sizeof(half), 0, 0, 0, 0, 0, 0, stream); } #else const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion @@ -10284,6 +10840,170 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s } } +inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV) { + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(K->type == GGML_TYPE_F32); + GGML_ASSERT(V->type == GGML_TYPE_F32); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(Q->backend == GGML_BACKEND_GPU); + GGML_ASSERT(K->backend == GGML_BACKEND_GPU); + GGML_ASSERT(V->backend == GGML_BACKEND_GPU); + GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU); + + ggml_cuda_set_device(g_main_device); + const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra; + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra; + ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra; + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra; + + const int64_t d_head = Q->ne[0]; + const int64_t sequence_length = Q->ne[1]; + const int64_t num_heads = Q->ne[2]; + + GGML_ASSERT(Q->ne[0] == d_head); + GGML_ASSERT(K->ne[0] == d_head); + GGML_ASSERT(V->ne[1] == d_head); + + GGML_ASSERT(Q->ne[1] == sequence_length); + GGML_ASSERT(K->ne[1] == sequence_length); + GGML_ASSERT(V->ne[0] == sequence_length); + + GGML_ASSERT(Q->ne[2] == num_heads); + GGML_ASSERT(K->ne[2] == num_heads); + GGML_ASSERT(V->ne[2] == num_heads); + + float KQ_scale = 1.0f / sqrtf((float)d_head); + + flash_attn_f32_cuda( + (float *) src0_extra->data_device[g_main_device], // Query + (float *) src1_extra->data_device[g_main_device], // Key + (float *) src2_extra->data_device[g_main_device], // Value + (float *) dst_extra->data_device[g_main_device], // dst + KQ_scale, d_head, sequence_length, num_heads, main_stream); +} + + +inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) { + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(K->type == GGML_TYPE_F16); + GGML_ASSERT(V->type == GGML_TYPE_F16); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(Q->backend == GGML_BACKEND_GPU); + GGML_ASSERT(K->backend == GGML_BACKEND_GPU); + GGML_ASSERT(V->backend == GGML_BACKEND_GPU); + GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU); + + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU); + GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + + ggml_cuda_set_device(g_main_device); + const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra; + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra; + ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra; + ggml_tensor_extra_gpu * src3_extra = mask ? (ggml_tensor_extra_gpu *) mask->extra : nullptr; + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + +#define NQPB 16 +#define NCPW 128 + + const int nqpb = NQPB; // queries per block + const int ncpw = NCPW; // cache values per warp (does not work for other values) + + const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? + // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why + const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 2; + + dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); + dim3 block_dim(32, nwarps, 1); + + const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); + + switch (Q->ne[0]) + { + case 16: + flash_attn_ext_f16<16, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<64, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 80: + flash_attn_ext_f16<80, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 128: + flash_attn_ext_f16<128, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + break; + } +} + static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); } @@ -10573,6 +11293,10 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st case GGML_OP_ARGSORT: func = ggml_cuda_argsort; break; + case GGML_OP_FLASH_ATTN: + break; + case GGML_OP_FLASH_ATTN_EXT: + break; default: return false; } @@ -10587,7 +11311,13 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return true; } - func(tensor->src[0], tensor->src[1], tensor); + if(tensor->op == GGML_OP_FLASH_ATTN) { + ggml_cuda_flash_attn(tensor->src[0], tensor->src[1], tensor->src[2], tensor); + } else if(tensor->op == GGML_OP_FLASH_ATTN_EXT) { + ggml_cuda_flash_attn_ext(tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } else { + func(tensor->src[0], tensor->src[1], tensor); + } return true; } @@ -11403,6 +12133,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_UPSCALE: case GGML_OP_PAD: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; default: return false; diff --git a/llama.cpp b/llama.cpp index fe25839669efc..2330efff57bd3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6881,7 +6881,8 @@ static int llama_decode_internal( // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + // note: we pad the n_kv because certain GPU kernels require it (e.g. ggml_flash_attn_ext) + kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(128, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128))); //kv_self.n = llama_kv_cache_cell_max(kv_self); //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); From 56e45a239e1d5a871009aa162b7ba99c93c40b62 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 20:16:32 +0200 Subject: [PATCH 078/180] metal : optimize softmax for C > 32 --- ggml-metal.metal | 16 +++++++++++----- tests/test-backend-ops.cpp | 9 +++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 04c1aaf9cdfb9..3d5d762d16d99 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2217,29 +2217,35 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; - smax = simd_max(max(smax, s)); - M[j] = simd_max(max(M[j], s)); + smax = max(smax, s); + M[j] = max(M[j], s); } - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + smax = simd_max(smax); + M[j] = simd_max(M[j]); - S[j] = S[j]*ms; + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); // create a QxQ diagonal matrix for rescaling the output if (tiisg == j) { ss[j*T + C + j] = ms; } + // local sum + half ls = 0.0h; + for (int64_t p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - S[j] = S[j] + simd_sum(vs); + ls += vs; // the P matrix from the paper (Q rows, C columns) ss[j*T + p] = vs; } + + S[j] = S[j]*ms + simd_sum(ls); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b1b30b91c9c6b..2ab5354069853 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -572,9 +572,18 @@ struct test_case { // duplicate the op size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1; +#if 1 for (int i = 1; i < n_runs; i++) { gf->nodes[gf->n_nodes++] = out; } +#else + int n_nodes = gf->n_nodes; + for (int i = 1; i < n_runs; i++) { + for (int j = 0; j < n_nodes; j++) { + gf->nodes[gf->n_nodes++] = gf->nodes[j]; + } + } +#endif // calculate memory size_t mem = n_runs * op_size(out); From cda5a60a41c669f233a943e1182cd6625f61a924 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 20:53:29 +0200 Subject: [PATCH 079/180] metal : optimize softmax --- ggml-metal.m | 5 +++-- ggml-metal.metal | 34 +++++++++++++++++++--------------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index e00069624551f..2bbb6d17a36ad 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2285,8 +2285,9 @@ static bool ggml_metal_graph_compute( const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - GGML_ASSERT(nqptg % 8 == 0); - GGML_ASSERT(ncpsg % 32 == 0); + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) diff --git a/ggml-metal.metal b/ggml-metal.metal index 3d5d762d16d99..d9a536ae87109 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2188,6 +2188,8 @@ kernel void kernel_flash_attn_ext_f16( // online softmax if (C == 32) { + half ms[Q]; + for (int64_t j = 0; j < Q; ++j) { const int64_t p = tiisg; @@ -2197,20 +2199,22 @@ kernel void kernel_flash_attn_ext_f16( smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - - S[j] = S[j]*ms + simd_sum(vs); + ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - // create a QxQ diagonal matrix for rescaling the output - if (p == j) { - ss[j*T + C + j] = ms; - } + S[j] = S[j]*ms[j] + simd_sum(vs); // the P matrix from the paper (Q rows, C columns) ss[j*T + p] = vs; } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*T + C + tiisg] = ms[tiisg]; + } } else { + half ms[Q]; + for (int64_t j = 0; j < Q; ++j) { const half m = M[j]; @@ -2224,12 +2228,7 @@ kernel void kernel_flash_attn_ext_f16( smax = simd_max(smax); M[j] = simd_max(M[j]); - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - - // create a QxQ diagonal matrix for rescaling the output - if (tiisg == j) { - ss[j*T + C + j] = ms; - } + ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); // local sum half ls = 0.0h; @@ -2245,7 +2244,12 @@ kernel void kernel_flash_attn_ext_f16( ss[j*T + p] = vs; } - S[j] = S[j]*ms + simd_sum(ls); + S[j] = S[j]*ms[j] + simd_sum(ls); + } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*T + C + tiisg] = ms[tiisg]; } } From c6769b942229a9e634965e6215651b8d4cf02403 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 21:24:26 +0200 Subject: [PATCH 080/180] tests : minor fix --- tests/test-backend-ops.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2ab5354069853..727f2ea06a5d7 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -578,6 +578,7 @@ struct test_case { } #else int n_nodes = gf->n_nodes; + n_runs = 1000; for (int i = 1; i < n_runs; i++) { for (int j = 0; j < n_nodes; j++) { gf->nodes[gf->n_nodes++] = gf->nodes[j]; From db1f3c482e256398330d44ad22b498ca2cd03161 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 22:08:37 +0200 Subject: [PATCH 081/180] cuda : avoid zeroing fragments --- ggml-cuda.cu | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 098b55e073c12..7130209e702d2 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6443,11 +6443,11 @@ static __global__ void flash_attn_ext_f16( const int iq2 = blockIdx.y; const int iq1 = blockIdx.x * Q; - const int D2 = D/2; + const int D2 = D/2; const int D16 = D/16; const int Q16 = Q/16; - const int NW = WARP_SIZE; - const int SH = (C + Q); // shared memory per simdgroup in (half) + const int NW = WARP_SIZE; + const int SH = (C + Q); // shared memory per simdgroup in (half) const int T = D + num_warps*SH; // shared memory size per query in (half) const int T2 = T/2; // shared memory size per query in (half2) @@ -6665,8 +6665,7 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); - nvcuda::wmma::fill_fragment(lo[j][i], 0.0); - nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]); + nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); } // restore zeros @@ -6760,9 +6759,8 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); for (int64_t i = 0; i < D16; ++i) { - nvcuda::wmma::fill_fragment(t2, 0.0); nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); - nvcuda::wmma::mma_sync(t2, ms1, t, t2); + nvcuda::wmma::mma_sync(t2, ms1, t, zr); // convert accumulator to matrix_b nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); From 12eaa22628740e388789081ccba93159c1b0b412 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 2 Feb 2024 11:55:38 +0200 Subject: [PATCH 082/180] tests : update dims --- ggml-cuda.cu | 180 ++++++++++++++++++++++--------------- tests/test-backend-ops.cpp | 6 +- 2 files changed, 110 insertions(+), 76 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7130209e702d2..2c050c0c44edc 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6568,7 +6568,8 @@ static __global__ void flash_attn_ext_f16( for (int64_t j = 0; j < Q16; ++j) { half16x16_a mqka; half16x16_acc mm; - if(mp) { + + if (mp) { nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); } @@ -10927,78 +10928,111 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); - switch (Q->ne[0]) - { - case 16: - flash_attn_ext_f16<16, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<64, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 80: - flash_attn_ext_f16<80, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 128: - flash_attn_ext_f16<128, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - default: - break; + switch (Q->ne[0]) { + case 64: + flash_attn_ext_f16<64, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 80: + flash_attn_ext_f16<80, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 96: + flash_attn_ext_f16<96, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 112: + flash_attn_ext_f16<112, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 128: + flash_attn_ext_f16<128, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 256: + flash_attn_ext_f16<256, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + break; } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 727f2ea06a5d7..9feb5e1fe550e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -572,7 +572,7 @@ struct test_case { // duplicate the op size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1; -#if 1 +#if 0 for (int i = 1; i < n_runs; i++) { gf->nodes[gf->n_nodes++] = out; } @@ -2209,8 +2209,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); -#if 0 - for (int hs : { 64, 80, 96, 112, 128, 256, }) { +#if 1 + for (int hs : { 64, 80, 128, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { From b68a112204c58e2bed334273753211c15acc48e2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 2 Feb 2024 15:12:28 +0200 Subject: [PATCH 083/180] cuda : fix __hisinf() result check --- ggml-cuda.cu | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2c050c0c44edc..0136fbf28f2a5 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6597,8 +6597,8 @@ static __global__ void flash_attn_ext_f16( smax = warp_reduce_max(__hmax(smax, s)); M[j] = warp_reduce_max(__hmax(M[j], s)); - const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); - const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]); + const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); + const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); S[j] = S[j]*ms + warp_reduce_sum(vs); @@ -6624,7 +6624,7 @@ static __global__ void flash_attn_ext_f16( smax = warp_reduce_max(smax); M[j] = warp_reduce_max(M[j]); - const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); + const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); // create a QxQ diagonal matrix for rescaling the output if (lane_id == j) { @@ -6637,7 +6637,7 @@ static __global__ void flash_attn_ext_f16( for (int64_t p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; - const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]); + const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); ls += vs; @@ -6650,7 +6650,7 @@ static __global__ void flash_attn_ext_f16( } // skip -INF blocks - if (__hisinf(smax)) { + if (__hisinf(smax) == -1) { continue; } @@ -6735,8 +6735,8 @@ static __global__ void flash_attn_ext_f16( M = __hmax(M0, M1); - const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M); - const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M); + const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M); + const half ms1 = __hisinf(M1) == -1 ? __float2half(0.0f) : hexp(M1 - M); S = S0*ms0 + S1*ms1; From b150abe83e6f0f8a0cf552d7fc0d8fe9f0f78569 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 13:17:47 +0200 Subject: [PATCH 084/180] cuda : avoid warp_reduce for smax --- ggml-cuda.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 0136fbf28f2a5..c3f24242b350a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6621,7 +6621,6 @@ static __global__ void flash_attn_ext_f16( M[j] = __hmax(M[j], s); } - smax = warp_reduce_max(smax); M[j] = warp_reduce_max(M[j]); const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); @@ -6649,6 +6648,8 @@ static __global__ void flash_attn_ext_f16( } } + smax = warp_reduce_max(smax); + // skip -INF blocks if (__hisinf(smax) == -1) { continue; From 7c34655b366e14d43f7fc9fa104a9ca7b8f60580 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 13:39:46 +0200 Subject: [PATCH 085/180] cuda : use int instead of int64_t Noticeably improves performance (thanks to Johannes) --- ggml-cuda.cu | 70 ++++++++++++++++++++++++++-------------------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c3f24242b350a..558ffb8ac7b56 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6462,10 +6462,10 @@ static __global__ void flash_attn_ext_f16( half16x16_acc lo[Q16][D16]; // load heads from Q to shared memory - for (int64_t j = warp_id; j < Q; j += num_warps) { + for (int j = warp_id; j < Q; j += num_warps) { const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - for (int64_t i = lane_id; i < D2; i += NW) { + for (int i = lane_id; i < D2; i += NW) { if (iq1 + j < ne01) { sq2[j*T2 + i] = __float22half2_rn(q2[i]); } else { @@ -6477,15 +6477,15 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::fill_fragment(zr, 0.0); // zero out lo - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::fill_fragment(lo[j][i], 0.0); } } // zero out shared memory SH - for (int64_t j = 0; j < Q; ++j) { - for (int64_t i = lane_id; i < SH; i += NW) { + for (int j = 0; j < Q; ++j) { + for (int i = lane_id; i < SH; i += NW) { ss[j*T + i] = 0.0; } } @@ -6526,8 +6526,8 @@ static __global__ void flash_attn_ext_f16( // load the queries from shared memory into local memory half16x16_a mq[Q16][D16]; - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); } } @@ -6544,28 +6544,28 @@ static __global__ void flash_attn_ext_f16( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) { + for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) { // Q*K^T { for (int cc = 0; cc < C/16; ++cc) { half16x16_acc mqk[Q16]; - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::fill_fragment(mqk[j], 0); } const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { half16x16_bT mk; // transposed key nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); } } // mqk = mqk*scale + mask - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a mqka; half16x16_acc mm; @@ -6588,8 +6588,8 @@ static __global__ void flash_attn_ext_f16( // online softmax if (C == 32) { - for (int64_t j = 0; j < Q; ++j) { - const int64_t p = lane_id; + for (int j = 0; j < Q; ++j) { + const int p = lane_id; const half m = M[j]; const half s = ss[j*T + p]; @@ -6611,10 +6611,10 @@ static __global__ void flash_attn_ext_f16( ss[j*T + p] = vs; } } else { - for (int64_t j = 0; j < Q; ++j) { + for (int j = 0; j < Q; ++j) { const half m = M[j]; - for (int64_t p = lane_id; p < C; p += NW) { + for (int p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; smax = __hmax(smax, s); @@ -6633,7 +6633,7 @@ static __global__ void flash_attn_ext_f16( // local sum half ls = 0.0f; - for (int64_t p = lane_id; p < C; p += NW) { + for (int p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); @@ -6656,13 +6656,13 @@ static __global__ void flash_attn_ext_f16( } // O = diag(ms)*O - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a mm; half16x16_b lob; nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { // convert accumulator to matrix_b nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); @@ -6680,17 +6680,17 @@ static __global__ void flash_attn_ext_f16( const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); half16x16_b mk[D16]; - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half)); } half16x16_a mv[Q16]; - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T); } - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]); } } @@ -6699,7 +6699,7 @@ static __global__ void flash_attn_ext_f16( } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (int64_t j = 0; j < Q; ++j) { + for (int j = 0; j < Q; ++j) { if (lane_id == 0) { ss[j*T + 0] = S[j]; ss[j*T + 1] = M[j]; @@ -6708,7 +6708,7 @@ static __global__ void flash_attn_ext_f16( } // reduce the warps sequentially - for (int64_t sg = 1; sg < num_warps; ++sg) { + for (int sg = 1; sg < num_warps; ++sg) { half S = __float2half(0.0f); half M = __float2half(-INFINITY); @@ -6716,8 +6716,8 @@ static __global__ void flash_attn_ext_f16( // each simdgroup stores its output to shared memory, reusing sq if (warp_id == sg) { - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); } } @@ -6727,7 +6727,7 @@ static __global__ void flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (warp_id == 0) { - for (int64_t j = 0; j < Q; ++j) { + for (int j = 0; j < Q; ++j) { const half S0 = ss[j*T + 0]; const half S1 = ss[j*T + sg*SH + 0]; @@ -6751,7 +6751,7 @@ static __global__ void flash_attn_ext_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a ms0; half16x16_a ms1; half16x16_b t; @@ -6760,7 +6760,7 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); nvcuda::wmma::mma_sync(t2, ms1, t, zr); @@ -6776,8 +6776,8 @@ static __global__ void flash_attn_ext_f16( // store result to shared memory (reuse sq) if (warp_id == 0) { - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); } } @@ -6785,10 +6785,10 @@ static __global__ void flash_attn_ext_f16( // final rescale with 1/S and store to global memory if (warp_id == 0) { - for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + for (int j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int64_t i = lane_id; i < D; i += NW) { + for (int i = lane_id; i < D; i += NW) { dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); } } From 1f8a5924823aecaa6ab1d5c2ac70ddde1d6c27d0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 14:01:32 +0200 Subject: [PATCH 086/180] cuda : make loops use the same loop values Thanks Johannes again for the tip --- ggml-cuda.cu | 43 +++++++++++++++++++++++++++++++------- tests/test-backend-ops.cpp | 2 +- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 558ffb8ac7b56..a3a6c6455017b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6462,10 +6462,20 @@ static __global__ void flash_attn_ext_f16( half16x16_acc lo[Q16][D16]; // load heads from Q to shared memory - for (int j = warp_id; j < Q; j += num_warps) { + for (int j0 = 0; j0 < Q; j0 += num_warps) { + const int j = j0 + warp_id; + if (j >= Q) { + break; + } + const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - for (int i = lane_id; i < D2; i += NW) { + for (int i0 = 0; i0 < D2; i0 += NW) { + const int i = i0 + lane_id; + if (i >= D2) { + break; + } + if (iq1 + j < ne01) { sq2[j*T2 + i] = __float22half2_rn(q2[i]); } else { @@ -6485,7 +6495,12 @@ static __global__ void flash_attn_ext_f16( // zero out shared memory SH for (int j = 0; j < Q; ++j) { - for (int i = lane_id; i < SH; i += NW) { + for (int i0 = 0; i0 < SH; i0 += NW) { + const int i = i0 + lane_id; + if (i >= SH) { + break; + } + ss[j*T + i] = 0.0; } } @@ -6544,7 +6559,12 @@ static __global__ void flash_attn_ext_f16( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) { + for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) { + const int ic = ic0 + warp_id*C; + if (ic >= ne11) { + break; + } + // Q*K^T { for (int cc = 0; cc < C/16; ++cc) { @@ -6614,7 +6634,9 @@ static __global__ void flash_attn_ext_f16( for (int j = 0; j < Q; ++j) { const half m = M[j]; - for (int p = lane_id; p < C; p += NW) { + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; + const half s = ss[j*T + p]; smax = __hmax(smax, s); @@ -6633,7 +6655,9 @@ static __global__ void flash_attn_ext_f16( // local sum half ls = 0.0f; - for (int p = lane_id; p < C; p += NW) { + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; + const half s = ss[j*T + p]; const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); @@ -6788,7 +6812,12 @@ static __global__ void flash_attn_ext_f16( for (int j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int i = lane_id; i < D; i += NW) { + for (int i0 = 0; i0 < D; i0 += NW) { + const int i = i0 + lane_id; + if (i >= D) { + break; + } + dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 9feb5e1fe550e..e4076b49c180d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2210,7 +2210,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_leaky_relu()); #if 1 - for (int hs : { 64, 80, 128, }) { + for (int hs : { 128, 64, 80, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { From 92472ea22ca3eed7f65114b0e6b7de1585930759 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 14:10:01 +0200 Subject: [PATCH 087/180] cuda : unroll some of the loops --- ggml-cuda.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index a3a6c6455017b..deda4cc706fdc 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6462,6 +6462,7 @@ static __global__ void flash_attn_ext_f16( half16x16_acc lo[Q16][D16]; // load heads from Q to shared memory +#pragma unroll for (int j0 = 0; j0 < Q; j0 += num_warps) { const int j = j0 + warp_id; if (j >= Q) { @@ -6470,6 +6471,7 @@ static __global__ void flash_attn_ext_f16( const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); +#pragma unroll for (int i0 = 0; i0 < D2; i0 += NW) { const int i = i0 + lane_id; if (i >= D2) { From c51f27c0dbd70fe8eda6182d61371d6a2dea6fb9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 14:27:36 +0200 Subject: [PATCH 088/180] cuda : avoid __hisinf branches --- ggml-cuda.cu | 83 ++++++++++++++++++---------------------------------- 1 file changed, 29 insertions(+), 54 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index deda4cc706fdc..4d1fb008c3994 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6513,9 +6513,9 @@ static __global__ void flash_attn_ext_f16( half S[Q]; half M[Q]; - for(int i = 0; i < Q; i++) { + for (int i = 0; i < Q; ++i) { S[i] = __float2half(0.0f); - M[i] = __float2half(-INFINITY); + M[i] = CUDART_MIN_DENORM_FP16; } // assume K and V are same shape @@ -6609,69 +6609,44 @@ static __global__ void flash_attn_ext_f16( half smax = __float2half(-INFINITY); // online softmax - if (C == 32) { - for (int j = 0; j < Q; ++j) { - const int p = lane_id; - - const half m = M[j]; - const half s = ss[j*T + p]; - - smax = warp_reduce_max(__hmax(smax, s)); - M[j] = warp_reduce_max(__hmax(M[j], s)); - - const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); - const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); + for (int j = 0; j < Q; ++j) { + const half m = M[j]; - S[j] = S[j]*ms + warp_reduce_sum(vs); + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; - // create a QxQ diagonal matrix for rescaling the output - if (p == j) { - ss[j*T + C + j] = ms; - } + const half s = ss[j*T + p]; - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + smax = __hmax(smax, s); + M[j] = __hmax(M[j], s); } - } else { - for (int j = 0; j < Q; ++j) { - const half m = M[j]; - - for (int p0 = 0; p0 < C; p0 += NW) { - const int p = p0 + lane_id; - const half s = ss[j*T + p]; + M[j] = warp_reduce_max(M[j]); - smax = __hmax(smax, s); - M[j] = __hmax(M[j], s); - } - - M[j] = warp_reduce_max(M[j]); - - const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); + const half ms = hexp(m - M[j]); - // create a QxQ diagonal matrix for rescaling the output - if (lane_id == j) { - ss[j*T + C + j] = ms; - } - - // local sum - half ls = 0.0f; + // create a QxQ diagonal matrix for rescaling the output + if (lane_id == j) { + ss[j*T + C + j] = ms; + } - for (int p0 = 0; p0 < C; p0 += NW) { - const int p = p0 + lane_id; + // local sum + half ls = 0.0f; - const half s = ss[j*T + p]; + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; - const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); + const half s = ss[j*T + p]; - ls += vs; + const half vs = hexp(s - M[j]); - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; - } + ls += vs; - S[j] = S[j]*ms + warp_reduce_sum(ls); + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; } + + S[j] = S[j]*ms + warp_reduce_sum(ls); } smax = warp_reduce_max(smax); @@ -6736,7 +6711,7 @@ static __global__ void flash_attn_ext_f16( // reduce the warps sequentially for (int sg = 1; sg < num_warps; ++sg) { half S = __float2half(0.0f); - half M = __float2half(-INFINITY); + half M = CUDART_MIN_DENORM_FP16; __syncthreads(); @@ -6762,8 +6737,8 @@ static __global__ void flash_attn_ext_f16( M = __hmax(M0, M1); - const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M); - const half ms1 = __hisinf(M1) == -1 ? __float2half(0.0f) : hexp(M1 - M); + const half ms0 = hexp(M0 - M); + const half ms1 = hexp(M1 - M); S = S0*ms0 + S1*ms1; From b958151e3f66e17a9bc5131e446a50c5529b4b81 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 15:00:25 +0200 Subject: [PATCH 089/180] cuda : use half2 in softmax --- ggml-cuda.cu | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 4d1fb008c3994..1fed9d23e2f47 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6451,12 +6451,14 @@ static __global__ void flash_attn_ext_f16( const int T = D + num_warps*SH; // shared memory size per query in (half) const int T2 = T/2; // shared memory size per query in (half2) + const int C2 = C/2; extern __shared__ half __flash_attn_f16_shmem[]; // pq half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix + half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2 half16x16_acc zr; half16x16_acc lo[Q16][D16]; @@ -6606,19 +6608,19 @@ static __global__ void flash_attn_ext_f16( } // used to detect blocks full of -INF - half smax = __float2half(-INFINITY); + half2 smax = make_half2(-INFINITY, -INFINITY); // online softmax for (int j = 0; j < Q; ++j) { const half m = M[j]; - for (int p0 = 0; p0 < C; p0 += NW) { + for (int p0 = 0; p0 < C2; p0 += NW) { const int p = p0 + lane_id; - const half s = ss[j*T + p]; + const half2 s = ss2[j*T2 + p]; - smax = __hmax(smax, s); - M[j] = __hmax(M[j], s); + smax = __hmax2(smax, s); + M[j] = __hmax(M[j], __hmax(s.x, s.y)); } M[j] = warp_reduce_max(M[j]); @@ -6631,28 +6633,31 @@ static __global__ void flash_attn_ext_f16( } // local sum - half ls = 0.0f; + half2 ls = make_half2(0.0f, 0.0f); + half2 M2 = make_half2(M[j], M[j]); - for (int p0 = 0; p0 < C; p0 += NW) { + for (int p0 = 0; p0 < C2; p0 += NW) { const int p = p0 + lane_id; - const half s = ss[j*T + p]; + const half2 s = ss2[j*T2 + p]; - const half vs = hexp(s - M[j]); + const half2 vs = h2exp(s - M2); ls += vs; // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + ss2[j*T2 + p] = vs; } - S[j] = S[j]*ms + warp_reduce_sum(ls); + ls = warp_reduce_sum(ls); + + S[j] = S[j]*ms + ls.x + ls.y; } smax = warp_reduce_max(smax); // skip -INF blocks - if (__hisinf(smax) == -1) { + if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) { continue; } From a7b471569bdf4e09e97b2d02c27989b8cb801861 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 15:17:49 +0200 Subject: [PATCH 090/180] cuda : switch to 1 warp for bs > 16 --- ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 1fed9d23e2f47..c98b551b31e8e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -10933,7 +10933,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why - const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 2; + const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1; dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); dim3 block_dim(32, nwarps, 1); From 3b1c4e76739031bee3028748e0cd288c148f77b4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 15:36:05 +0200 Subject: [PATCH 091/180] cuda : speed-up reduce part of the kernel --- ggml-cuda.cu | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c98b551b31e8e..67541a61ef716 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6715,9 +6715,6 @@ static __global__ void flash_attn_ext_f16( // reduce the warps sequentially for (int sg = 1; sg < num_warps; ++sg) { - half S = __float2half(0.0f); - half M = CUDART_MIN_DENORM_FP16; - __syncthreads(); // each simdgroup stores its output to shared memory, reusing sq @@ -6733,27 +6730,25 @@ static __global__ void flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (warp_id == 0) { - for (int j = 0; j < Q; ++j) { + for (int j = lane_id; j < Q; j += NW) { const half S0 = ss[j*T + 0]; const half S1 = ss[j*T + sg*SH + 0]; const half M0 = ss[j*T + 1]; const half M1 = ss[j*T + sg*SH + 1]; - M = __hmax(M0, M1); + const half M = __hmax(M0, M1); const half ms0 = hexp(M0 - M); const half ms1 = hexp(M1 - M); - S = S0*ms0 + S1*ms1; + const half S = S0*ms0 + S1*ms1; - if (lane_id == 0) { - ss[j*T + 0] = S; - ss[j*T + 1] = M; + ss[j*T + 0] = S; + ss[j*T + 1] = M; - ss[j*T + C + j ] = ms0; - ss[j*T + C + j + sg*SH] = ms1; - } + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 @@ -10931,6 +10926,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const int nqpb = NQPB; // queries per block const int ncpw = NCPW; // cache values per warp (does not work for other values) + GGML_ASSERT(NQPB <= 32); + const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1; From 5b263dd83a5f906eddd10bc044051d7571097043 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 16:12:20 +0200 Subject: [PATCH 092/180] cuda : unroll Q*K^T loop --- ggml-cuda.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 67541a61ef716..dbd4822396f4e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6571,6 +6571,7 @@ static __global__ void flash_attn_ext_f16( // Q*K^T { +#pragma unroll for (int cc = 0; cc < C/16; ++cc) { half16x16_acc mqk[Q16]; for (int j = 0; j < Q16; ++j) { From e04ff391819e1875beed3e30d9e7592db45e0e62 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 16:57:46 +0200 Subject: [PATCH 093/180] cuda : fix -INF block check --- ggml-cuda.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index dbd4822396f4e..e51ddc08f764f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6658,7 +6658,7 @@ static __global__ void flash_attn_ext_f16( smax = warp_reduce_max(smax); // skip -INF blocks - if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) { + if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) { continue; } @@ -6676,8 +6676,10 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); } + } - // restore zeros + // restore zeros + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); } From cfd9732b2e45a442f4f7261ac0b50ec6e0862ab2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 18:31:55 +0200 Subject: [PATCH 094/180] cuda : simplify softmax --- ggml-cuda.cu | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e51ddc08f764f..25f810cbea86f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6512,11 +6512,10 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); { - half S[Q]; + half S = __float2half(0.0f); half M[Q]; for (int i = 0; i < Q; ++i) { - S[i] = __float2half(0.0f); M[i] = CUDART_MIN_DENORM_FP16; } @@ -6626,13 +6625,6 @@ static __global__ void flash_attn_ext_f16( M[j] = warp_reduce_max(M[j]); - const half ms = hexp(m - M[j]); - - // create a QxQ diagonal matrix for rescaling the output - if (lane_id == j) { - ss[j*T + C + j] = ms; - } - // local sum half2 ls = make_half2(0.0f, 0.0f); half2 M2 = make_half2(M[j], M[j]); @@ -6652,7 +6644,14 @@ static __global__ void flash_attn_ext_f16( ls = warp_reduce_sum(ls); - S[j] = S[j]*ms + ls.x + ls.y; + const half ms = hexp(m - M[j]); + + // create a QxQ diagonal matrix for rescaling the output + if (lane_id == j) { + ss[j*T + C + j] = ms; + + S = S*ms + ls.x + ls.y; + } } smax = warp_reduce_max(smax); @@ -6709,8 +6708,8 @@ static __global__ void flash_attn_ext_f16( // these are needed for reducing the results from the simdgroups (reuse the ss buffer) for (int j = 0; j < Q; ++j) { - if (lane_id == 0) { - ss[j*T + 0] = S[j]; + if (lane_id == j) { + ss[j*T + 0] = S; ss[j*T + 1] = M[j]; } } From ef68fac2a8b51e2237234e3d7c6120cade457ce8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 18:36:58 +0200 Subject: [PATCH 095/180] cuda : fix matrix names --- ggml-cuda.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 25f810cbea86f..d9ab2bd093feb 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6687,19 +6687,19 @@ static __global__ void flash_attn_ext_f16( for (int cc = 0; cc < C/16; ++cc) { const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); - half16x16_b mk[D16]; + half16x16_b mv[D16]; for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half)); + nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half)); } - half16x16_a mv[Q16]; + half16x16_a ms[Q16]; for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T); + nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T); } for (int j = 0; j < Q16; ++j) { for (int i = 0; i < D16; ++i) { - nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]); + nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]); } } } From 1846e92a904ef17a55a3f7e7c2e837f35db2ce7d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 4 Feb 2024 09:57:58 +0200 Subject: [PATCH 096/180] cuda : minor --- ggml-cuda.cu | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d9ab2bd093feb..713a6a89acfdc 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6399,10 +6399,10 @@ static __global__ void flash_attn_f32( } #if __CUDA_ARCH__ >= CC_VOLTA -typedef nvcuda::wmma::fragment half16x16_a; -typedef nvcuda::wmma::fragment half16x16_b; -typedef nvcuda::wmma::fragment half16x16_bT; -typedef nvcuda::wmma::fragment half16x16_acc; +typedef nvcuda::wmma::fragment half16x16_a; +typedef nvcuda::wmma::fragment half16x16_b; +typedef nvcuda::wmma::fragment half16x16_bT; +typedef nvcuda::wmma::fragment half16x16_acc; #endif // based on metal version @@ -6443,15 +6443,17 @@ static __global__ void flash_attn_ext_f16( const int iq2 = blockIdx.y; const int iq1 = blockIdx.x * Q; - const int D2 = D/2; const int D16 = D/16; const int Q16 = Q/16; + const int C16 = C/16; + const int NW = WARP_SIZE; const int SH = (C + Q); // shared memory per simdgroup in (half) const int T = D + num_warps*SH; // shared memory size per query in (half) const int T2 = T/2; // shared memory size per query in (half2) const int C2 = C/2; + const int D2 = D/2; extern __shared__ half __flash_attn_f16_shmem[]; // pq @@ -6571,7 +6573,7 @@ static __global__ void flash_attn_ext_f16( // Q*K^T { #pragma unroll - for (int cc = 0; cc < C/16; ++cc) { + for (int cc = 0; cc < C16; ++cc) { half16x16_acc mqk[Q16]; for (int j = 0; j < Q16; ++j) { nvcuda::wmma::fill_fragment(mqk[j], 0); @@ -6684,7 +6686,7 @@ static __global__ void flash_attn_ext_f16( // O = O + (Q*K^T)*V { - for (int cc = 0; cc < C/16; ++cc) { + for (int cc = 0; cc < C16; ++cc) { const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); half16x16_b mv[D16]; @@ -6707,11 +6709,9 @@ static __global__ void flash_attn_ext_f16( } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (int j = 0; j < Q; ++j) { - if (lane_id == j) { - ss[j*T + 0] = S; - ss[j*T + 1] = M[j]; - } + if (lane_id < Q) { + ss[lane_id*T + 0] = S; + ss[lane_id*T + 1] = M[lane_id]; } } @@ -10939,6 +10939,10 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); + // increase shared memory limit to 96KB + //const size_t shmem_max = 96*1024; + //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max); + switch (Q->ne[0]) { case 64: flash_attn_ext_f16<64, NQPB, NCPW> @@ -11045,6 +11049,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * default: break; } + + CUDA_CHECK(cudaGetLastError()); } static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { From f249c997a8b3f9b129fe825bebd609a362e9ab9c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 19 Feb 2024 13:10:24 +0200 Subject: [PATCH 097/180] llama : adapt to F16 KQ_pos --- ggml-cuda.cu | 2 +- ggml.c | 2 +- llama.cpp | 15 ++++++++++----- tests/test-backend-ops.cpp | 2 +- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2c8af51a66d59..5c6159a83f3cd 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6232,7 +6232,7 @@ static __global__ void soft_max_f32(const float * x, const half * mask, const ha const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? __half2float(slope*pos[col]) : 0.0f); + const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f); vals[col] = val; max_val = max(max_val, val); diff --git a/ggml.c b/ggml.c index efc570db698f2..9a2ae62647364 100644 --- a/ggml.c +++ b/ggml.c @@ -5192,7 +5192,7 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(mask->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_matrix(mask)); - GGML_ASSERT(ggml_can_repeat_rows(mask, a)); + GGML_ASSERT(mask->ne[1] >= a->ne[1]); } if (pos) { diff --git a/llama.cpp b/llama.cpp index 2359ed10aebf0..5aa3a508d7080 100644 --- a/llama.cpp +++ b/llama.cpp @@ -102,7 +102,7 @@ #define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_EXPERTS 8 -#define LLAMA_FLASH_ATTN +//#define LLAMA_FLASH_ATTN // // logging @@ -4831,6 +4831,11 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * cur; #if defined(LLAMA_FLASH_ATTN) + GGML_UNUSED(model); + GGML_UNUSED(n_ctx); + + GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); + // split cached v into n_head heads (not transposed) struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], @@ -5260,7 +5265,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); + struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16); cb(KQ_pos, "KQ_pos", -1); // shift the entire K-cache if needed @@ -5804,7 +5809,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); + struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16); cb(KQ_pos, "KQ_pos", -1); for (int il = 0; il < n_layer; ++il) { @@ -6043,7 +6048,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); + struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16); cb(KQ_pos, "KQ_pos", -1); inpL = llm_build_norm(ctx0, inpL, hparams, @@ -6140,7 +6145,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); + struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16); cb(KQ_pos, "KQ_pos", -1); for (int il = 0; il < n_layer; ++il) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 912223def6e06..278c57299ce88 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1505,7 +1505,7 @@ struct test_attn : public test_case { struct ggml_tensor * cur; cur = ggml_mul_mat (ctx, k, q); - cur = ggml_soft_max_ext(ctx, cur, mask, 1.0f/sqrtf(hs)); + cur = ggml_soft_max_ext(ctx, cur, mask, nullptr, 1.0f/sqrtf(hs), 0.0f); cur = ggml_mul_mat (ctx, v, cur); cur = ggml_permute (ctx, cur, 0, 2, 1, 3); cur = ggml_cont_2d (ctx, cur, hs*nh, nb); From f3fea1186199818859adb429404e515a53e4812c Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Mon, 26 Feb 2024 16:11:26 +0900 Subject: [PATCH 098/180] added local cblast --- CMakeLists.txt | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 24f1b94fd8e4b..c1864eff5a70c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -438,21 +438,24 @@ if (LLAMA_MPI) endif() if (LLAMA_CLBLAST) - # build CLBlast from source - # Add CLBlast directory - add_subdirectory(../CLBlast ${CMAKE_CURRENT_BINARY_DIR}/clblast) - - # Include directories for CLBlast - include_directories(../CLBlast/include) - - message(STATUS "CLBlast source found") + message(STATUS "Building with CLBlast") set(GGML_HEADERS_OPENCL ggml-opencl.h) set(GGML_SOURCES_OPENCL ggml-opencl.cpp) add_compile_definitions(GGML_USE_CLBLAST) - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast) + # link our libOpenCL.so (this is only used during compile time) + add_library(OpenCL SHARED IMPORTED) + set_target_properties(OpenCL PROPERTIES IMPORTED_LOCATION ${PROJECT_SOURCE_DIR}/../OpenCL/lib/libOpenCL.so) + + # add our prebuilt clblast library + add_library(clblast SHARED IMPORTED) + set_target_properties(clblast PROPERTIES IMPORTED_LOCATION ${PROJECT_SOURCE_DIR}/../../android/app/src/main/jniLibs/${ANDROID_ABI}/libclblast.so) + + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast OpenCL) + set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ../CLBlast/include) + set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ../OpenCL/include) endif() if (LLAMA_VULKAN) From 1d0fba9831be63d552fdbd8df9d5e2004e011a35 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Mon, 26 Feb 2024 19:07:48 +0900 Subject: [PATCH 099/180] updated openblas config --- CMakeLists.txt | 75 +++++--------------------------------------------- 1 file changed, 7 insertions(+), 68 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 26d6842ff175e..89587146a2398 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -261,6 +261,8 @@ if (LLAMA_METAL) ) endif() if (LLAMA_BLAS) + message(STATUS "Building with OpenBLAS") + if (LLAMA_STATIC) set(BLA_STATIC ON) endif() @@ -269,77 +271,14 @@ if (LLAMA_BLAS) endif() set(BLA_VENDOR ${LLAMA_BLAS_VENDOR}) - find_package(BLAS) - - if (BLAS_FOUND) - message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") - - if ("${BLAS_INCLUDE_DIRS}" STREQUAL "") - # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake. - # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268 - find_package(PkgConfig REQUIRED) - if (${LLAMA_BLAS_VENDOR} MATCHES "Generic") - pkg_check_modules(DepBLAS REQUIRED blas) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "OpenBLAS") - # As of openblas v0.3.22, the 64-bit is named openblas64.pc - pkg_check_modules(DepBLAS openblas64) - if (NOT DepBLAS_FOUND) - pkg_check_modules(DepBLAS REQUIRED openblas) - endif() - elseif (${LLAMA_BLAS_VENDOR} MATCHES "FLAME") - pkg_check_modules(DepBLAS REQUIRED blis) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "ATLAS") - pkg_check_modules(DepBLAS REQUIRED blas-atlas) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "FlexiBLAS") - pkg_check_modules(DepBLAS REQUIRED flexiblas_api) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "Intel") - # all Intel* libraries share the same include path - pkg_check_modules(DepBLAS REQUIRED mkl-sdl) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "NVHPC") - # this doesn't provide pkg-config - # suggest to assign BLAS_INCLUDE_DIRS on your own - if ("${NVHPC_VERSION}" STREQUAL "") - message(WARNING "Better to set NVHPC_VERSION") - else() - set(DepBLAS_FOUND ON) - set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include") - endif() - endif() - if (DepBLAS_FOUND) - set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS}) - else() - message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically" - " detected by pkgconfig, trying to find cblas.h from possible paths...") - find_path(BLAS_INCLUDE_DIRS - NAMES cblas.h - HINTS - /usr/include - /usr/local/include - /usr/include/openblas - /opt/homebrew/opt/openblas/include - /usr/local/opt/openblas/include - /usr/include/x86_64-linux-gnu/openblas/include - ) - endif() - endif() + add_compile_options(${BLAS_LINKER_FLAGS}) - message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}") + add_compile_definitions(GGML_USE_OPENBLAS) - add_compile_options(${BLAS_LINKER_FLAGS}) + add_subdirectory(../OpenBLAS ${CMAKE_CURRENT_BINARY_DIR}/OpenBLAS) - add_compile_definitions(GGML_USE_OPENBLAS) - - if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${LLAMA_BLAS_VENDOR} MATCHES "Generic" OR ${LLAMA_BLAS_VENDOR} MATCHES "Intel")) - add_compile_definitions(GGML_BLAS_USE_MKL) - endif() - - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES}) - set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS}) - else() - message(WARNING "BLAS not found, please refer to " - "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" - " to set correct LLAMA_BLAS_VENDOR") - endif() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} openblas_static) + set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ../OpenBLAS ${CMAKE_CURRENT_BINARY_DIR}/OpenBLAS) endif() if (LLAMA_QKK_64) From 6aefd11204199c9bd520b8991bab4085cb6fc977 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 3 Mar 2024 13:50:54 +0200 Subject: [PATCH 100/180] llama : adapt new models to F16 KQ_mask --- llama.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index 1a099adcba5dc..f2b224cafb2e3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7362,7 +7362,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -7489,7 +7489,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -7724,7 +7724,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { From 58c7f6167c0f6540f0da0386fc65d940e1a16ea5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 4 Mar 2024 20:44:57 +0200 Subject: [PATCH 101/180] ggml : fix F16 store (ARM NEON) --- ggml.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 09ed9da342885..5715b78ec7bbf 100644 --- a/ggml.c +++ b/ggml.c @@ -874,7 +874,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F16x8_FMA #define GGML_F16_VEC_ADD GGML_F16x8_ADD #define GGML_F16_VEC_MUL GGML_F16x8_MUL @@ -900,7 +900,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL From 3a468e6f9f0c7dff9ed78b0f7a5af069da420606 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 22 Mar 2024 17:12:17 +0200 Subject: [PATCH 102/180] llama : fix type of KQ_mask and KQ_pos --- llama.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/llama.cpp b/llama.cpp index 77da94960e65c..b80080daf7506 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5810,20 +5810,20 @@ struct llm_build_context { struct ggml_tensor * build_inp_KQ_mask(bool causal = true) { if (causal) { - lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); } else { - lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); } cb(lctx.inp_KQ_mask, "KQ_mask", -1); ggml_set_input(lctx.inp_KQ_mask); - return lctx.inp_KQ_mask; + return ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16); } struct ggml_tensor * build_inp_KQ_pos() { - lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F16, n_kv); + lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); cb(lctx.inp_KQ_pos, "KQ_pos", -1); ggml_set_input(lctx.inp_KQ_pos); - return lctx.inp_KQ_pos; + return ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16); } struct ggml_tensor * build_inp_mean() { From 09532120e0afc5b0c451fbb79eac558e4561660b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 22 Mar 2024 17:49:42 +0200 Subject: [PATCH 103/180] ggml : fix CPU soft_max --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 434fb76c3b1b6..7ea1abfcd67c6 100644 --- a/ggml.c +++ b/ggml.c @@ -12302,7 +12302,7 @@ static void ggml_compute_forward_soft_max_f32( float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); // broadcast the mask across rows - ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; + ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; ggml_vec_cpy_f32 (nc, wp, sp); ggml_vec_scale_f32(nc, wp, scale); From e425810bb6c8e9dd8ef9ff6f606313b6cfa1b607 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 24 Mar 2024 12:21:41 +0200 Subject: [PATCH 104/180] tests : add hs=256 --- ggml.c | 2 +- tests/test-backend-ops.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 7ea1abfcd67c6..336c3c0ffaade 100644 --- a/ggml.c +++ b/ggml.c @@ -12272,7 +12272,7 @@ static void ggml_compute_forward_soft_max_f32( GGML_TENSOR_UNARY_OP_LOCALS - const int64_t ne11 = src1 ? src1->ne[1] : 1; + //const int64_t ne11 = src1 ? src1->ne[1] : 1; // TODO: is this supposed to be ceil instead of floor? // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b8994cecff4b2..50c4b27294fea 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2261,7 +2261,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_leaky_relu()); #if 1 - for (int hs : { 128, 64, 80, }) { + for (int hs : { 128, 256, 64, 80, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { From 00c2ccb9aab1f35ae3b8d0427f99bf954b653864 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 26 Mar 2024 15:26:58 +0900 Subject: [PATCH 105/180] use shared openblas --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 838b556cb0ac8..aa7ae20e4949e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -289,7 +289,7 @@ if (LLAMA_BLAS) add_subdirectory(../OpenBLAS ${CMAKE_CURRENT_BINARY_DIR}/OpenBLAS) - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} openblas_static) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} openblas_shared) set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ../OpenBLAS ${CMAKE_CURRENT_BINARY_DIR}/OpenBLAS) endif() From 6be02b5969ed04fb3ce336702ae3949076bb2bf4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 27 Mar 2024 10:31:52 +0200 Subject: [PATCH 106/180] cuda : fix build --- ggml-cuda.cu | 7 +------ ggml-cuda/fattn.cu | 48 +++++++++++++++++++++++++++++++++++++++------ ggml-cuda/fattn.cuh | 5 +---- 3 files changed, 44 insertions(+), 16 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 963588c4f2fad..31bfc43c1febe 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2384,17 +2384,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_argsort(ctx, dst); break; case GGML_OP_FLASH_ATTN_EXT: + ggml_cuda_flash_attn_ext(ctx, dst); break; default: return false; } - if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { - ggml_cuda_flash_attn_ext(ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); - } else { - func(ctx, tensor->src[0], tensor->src[1], tensor); - } - cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst)); diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 0f135a184b572..bcf27fd794aaf 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -1,5 +1,33 @@ #include "fattn.cuh" +#include + +static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32)); + } + return a; +#else + GGML_UNUSED(a); + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +} + +static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +#else + GGML_UNUSED(x); + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +} + #if __CUDA_ARCH__ >= CC_VOLTA typedef nvcuda::wmma::fragment half16x16_a; typedef nvcuda::wmma::fragment half16x16_b; @@ -10,11 +38,11 @@ typedef nvcuda::wmma::fragment // based on metal version template // D head size, Q queries per block, C cache items per block static __global__ void flash_attn_ext_f16( - const char* __restrict__ q, - const char* __restrict__ k, - const char* __restrict__ v, - const char* __restrict__ mask, - float* __restrict__ dst, + const char * __restrict__ q, + const char * __restrict__ k, + const char * __restrict__ v, + const char * __restrict__ mask, + float * __restrict__ dst, float scale, int ne00, int ne01, @@ -408,7 +436,15 @@ static __global__ void flash_attn_ext_f16( #endif } -void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) { +void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + const ggml_tensor * mask = dst->src[3]; + + ggml_tensor * KQV = dst; + GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(K->type == GGML_TYPE_F16); GGML_ASSERT(V->type == GGML_TYPE_F16); diff --git a/ggml-cuda/fattn.cuh b/ggml-cuda/fattn.cuh index 1b764bc946f8f..ad3ca7a8d8e4d 100644 --- a/ggml-cuda/fattn.cuh +++ b/ggml-cuda/fattn.cuh @@ -1,6 +1,3 @@ #include "common.cuh" -void ggml_cuda_flash_attn_ext( - ggml_backend_cuda_context & ctx, - const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, - const ggml_tensor * mask, ggml_tensor * KQV); +void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From 57c03b78b6bca2049e43905e808666b04304d0fd Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 28 Mar 2024 19:29:06 +0200 Subject: [PATCH 107/180] metal : improve perf via smaller int registers --- ggml-metal.metal | 149 ++++++++++++++++++++++++----------------------- 1 file changed, 77 insertions(+), 72 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index dc6ad31417c51..27eeb3932ff1a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2061,11 +2061,11 @@ typedef void (flash_attn_ext_f16_t)( constant int64_t & ne3, constant float & scale, threadgroup half * shared, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]); + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]); // ref: https://arxiv.org/pdf/2307.08691.pdf template // head size, queries per threadgroup, cache items per threadgroup @@ -2099,25 +2099,25 @@ kernel void kernel_flash_attn_ext_f16( constant int64_t & ne3, constant float & scale, threadgroup half * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const uint nsg = ntg.y; // number of simdgroups - - const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]; - const int64_t iq1 = tgpig[0]*Q; - - const int64_t D4 = D/4; - const int64_t D8 = D/8; - const int64_t Q8 = Q/8; - const int64_t NW = N_SIMDWIDTH; - const int64_t SH = (C + Q); // shared memory per simdgroup in (half) - - const int64_t T = D + nsg*SH; // shared memory size per query in (half) - const int64_t T4 = T/4; // shared memory size per query in (half4) + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]*Q; + + const short D4 = D/4; + const short D8 = D/8; + const short Q8 = Q/8; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + nsg*SH; // shared memory size per query in (half) + const short T4 = T/4; // shared memory size per query in (half4) threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 @@ -2127,10 +2127,10 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_half8x8 lo[Q8][D8]; // load heads from Q to shared memory - for (int64_t j = sgitg; j < Q; j += nsg) { + for (short j = sgitg; j < Q; j += nsg) { device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - for (int64_t i = tiisg; i < D4; i += NW) { + for (short i = tiisg; i < D4; i += NW) { if (iq1 + j < ne01) { sq4[j*T4 + i] = (half4) q4[i]; } else { @@ -2140,15 +2140,15 @@ kernel void kernel_flash_attn_ext_f16( } // zero out lo - for (int64_t j = 0; j < Q8; ++j) { - for (int64_t i = 0; i < D8; ++i) { + for (short j = 0; j < Q8; ++j) { + for (short i = 0; i < D8; ++i) { lo[j][i] = make_filled_simdgroup_matrix(0.0h); } } // zero out shared memory SH - for (int64_t j = 0; j < Q; ++j) { - for (int64_t i = tiisg; i < SH; i += NW) { + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < SH; i += NW) { ss[j*T + i] = 0.0h; } } @@ -2160,33 +2160,33 @@ kernel void kernel_flash_attn_ext_f16( half M[Q] = { [0 ... Q-1] = -INFINITY }; // assume K and V are same shape - const int64_t ne22 = ne12; - const int64_t ne23 = ne13; + const short ne22 = ne12; + const short ne23 = ne13; - const uint64_t nb21 = nb11; - const uint64_t nb22 = nb12; - const uint64_t nb23 = nb13; + const uint nb21 = nb11; + const uint nb22 = nb12; + const uint nb23 = nb13; // broadcast - const int64_t rk2 = ne02/ne12; - const int64_t rk3 = ne03/ne13; + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; - const int64_t rv2 = ne02/ne22; - const int64_t rv3 = ne03/ne23; + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; // k indices - const int64_t ik2 = iq2 / rk2; - const int64_t ik3 = iq3 / rk3; + const short ik2 = iq2 / rk2; + const short ik3 = iq3 / rk3; // v indices - const int64_t iv2 = iq2 / rv2; - const int64_t iv3 = iq3 / rv3; + const short iv2 = iq2 / rv2; + const short iv3 = iq3 / rv3; // load the queries from shared memory into local memory simdgroup_half8x8 mq[Q8][D8]; - for (int64_t j = 0; j < Q8; ++j) { - for (int64_t i = 0; i < D8; ++i) { + for (short j = 0; j < Q8; ++j) { + for (short i = 0; i < D8; ++i) { simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); } } @@ -2199,28 +2199,33 @@ kernel void kernel_flash_attn_ext_f16( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) { + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + // Q*K^T { - for (int cc = 0; cc < C/8; ++cc) { + for (short cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mqk[Q8]; - for (int64_t j = 0; j < Q8; ++j) { + for (short j = 0; j < Q8; ++j) { mqk[j] = make_filled_simdgroup_matrix(0.h); } device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); - for (int64_t i = 0; i < D8; ++i) { + for (short i = 0; i < D8; ++i) { simdgroup_half8x8 mk; simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose - for (int64_t j = 0; j < Q8; ++j) { + for (short j = 0; j < Q8; ++j) { simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); } } // mqk = mqk*scale + mask - for (int64_t j = 0; j < Q8; ++j) { + for (short j = 0; j < Q8; ++j) { simdgroup_half8x8 mm; simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); @@ -2237,8 +2242,8 @@ kernel void kernel_flash_attn_ext_f16( if (C == 32) { half ms[Q]; - for (int64_t j = 0; j < Q; ++j) { - const int64_t p = tiisg; + for (short j = 0; j < Q; ++j) { + const short p = tiisg; const half m = M[j]; const half s = ss[j*T + p]; @@ -2262,10 +2267,10 @@ kernel void kernel_flash_attn_ext_f16( } else { half ms[Q]; - for (int64_t j = 0; j < Q; ++j) { + for (short j = 0; j < Q; ++j) { const half m = M[j]; - for (int64_t p = tiisg; p < C; p += NW) { + for (short p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; smax = max(smax, s); @@ -2280,7 +2285,7 @@ kernel void kernel_flash_attn_ext_f16( // local sum half ls = 0.0h; - for (int64_t p = tiisg; p < C; p += NW) { + for (short p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); @@ -2306,25 +2311,25 @@ kernel void kernel_flash_attn_ext_f16( } // O = diag(ms)*O - for (int64_t j = 0; j < Q8; ++j) { + for (short j = 0; j < Q8; ++j) { simdgroup_half8x8 mm; simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false); - for (int64_t i = 0; i < D8; ++i) { + for (short i = 0; i < D8; ++i) { simdgroup_multiply(lo[j][i], mm, lo[j][i]); } } // O = O + (Q*K^T)*V { - for (int cc = 0; cc < C/8; ++cc) { + for (short cc = 0; cc < C/8; ++cc) { device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); - for (int64_t i = 0; i < D8; ++i) { + for (short i = 0; i < D8; ++i) { simdgroup_half8x8 mk; simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); - for (int64_t j = 0; j < Q8; ++j) { + for (short j = 0; j < Q8; ++j) { simdgroup_half8x8 mv; simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false); @@ -2336,7 +2341,7 @@ kernel void kernel_flash_attn_ext_f16( } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (int64_t j = 0; j < Q; ++j) { + for (short j = 0; j < Q; ++j) { if (tiisg == 0) { ss[j*T + 0] = S[j]; ss[j*T + 1] = M[j]; @@ -2345,7 +2350,7 @@ kernel void kernel_flash_attn_ext_f16( } // reduce the warps sequentially - for (int64_t sg = 1; sg < nsg; ++sg) { + for (short sg = 1; sg < nsg; ++sg) { half S = { 0.0h }; half M = { -INFINITY }; @@ -2353,8 +2358,8 @@ kernel void kernel_flash_attn_ext_f16( // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { - for (int64_t j = 0; j < Q8; ++j) { - for (int64_t i = 0; i < D8; ++i) { + for (short j = 0; j < Q8; ++j) { + for (short i = 0; i < D8; ++i) { simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); } } @@ -2364,7 +2369,7 @@ kernel void kernel_flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (sgitg == 0) { - for (int64_t j = 0; j < Q; ++j) { + for (short j = 0; j < Q; ++j) { const half S0 = ss[j*T + 0]; const half S1 = ss[j*T + sg*SH + 0]; @@ -2388,7 +2393,7 @@ kernel void kernel_flash_attn_ext_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (int64_t j = 0; j < Q8; ++j) { + for (short j = 0; j < Q8; ++j) { simdgroup_half8x8 t; simdgroup_half8x8 ms0; simdgroup_half8x8 ms1; @@ -2396,7 +2401,7 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false); simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false); - for (int64_t i = 0; i < D8; ++i) { + for (short i = 0; i < D8; ++i) { simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false); simdgroup_multiply(t, ms1, t); @@ -2408,8 +2413,8 @@ kernel void kernel_flash_attn_ext_f16( // store result to shared memory (reuse sq) if (sgitg == 0) { - for (int64_t j = 0; j < Q8; ++j) { - for (int64_t i = 0; i < D8; ++i) { + for (short j = 0; j < Q8; ++j) { + for (short i = 0; i < D8; ++i) { simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); } } @@ -2419,10 +2424,10 @@ kernel void kernel_flash_attn_ext_f16( // final rescale with 1/S and store to global memory if (sgitg == 0) { - for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + for (short j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int64_t i = tiisg; i < D4; i += NW) { + for (short i = tiisg; i < D4; i += NW) { dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; } } From 08e69c50081622d8146e749195939157be2a2207 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 28 Mar 2024 19:40:11 +0200 Subject: [PATCH 108/180] cuda : adapt soft_max to F16 mask and pos --- ggml-cuda/softmax.cu | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu index 9bda18e581c75..8f6dca4d0f9bf 100644 --- a/ggml-cuda/softmax.cu +++ b/ggml-cuda/softmax.cu @@ -1,7 +1,7 @@ #include "softmax.cuh" template -static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { +static __global__ void soft_max_f32(const float * x, const half * mask, const half * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -43,7 +43,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f); + const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -114,7 +114,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f } } -static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const half * mask, const half * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -168,14 +168,14 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; const float * src0_d = (const float *)src0->data; - const float * src1_d = src1 ? (const float *)src1->data : nullptr; + const half * src1_d = src1 ? (const half *)src1->data : nullptr; float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -188,13 +188,13 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); // positions tensor - float * src2_dd = nullptr; + half * src2_dd = nullptr; ggml_tensor * src2 = dst->src[2]; const bool use_src2 = src2 != nullptr; if (use_src2) { - src2_dd = (float *)src2->data; + src2_dd = (half *)src2->data; } soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); From 75aa7b4b189a5a2f6518840e84ec489da71e0443 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 29 Mar 2024 23:02:39 +0100 Subject: [PATCH 109/180] CUDA: faster FlashAttention, kernel for bs == 1 --- ggml-cuda/fattn.cu | 1357 +++++++++++++++++++++++++++++--------------- 1 file changed, 906 insertions(+), 451 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index bcf27fd794aaf..ccb3c924609a4 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -28,414 +28,416 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX } -#if __CUDA_ARCH__ >= CC_VOLTA -typedef nvcuda::wmma::fragment half16x16_a; -typedef nvcuda::wmma::fragment half16x16_b; -typedef nvcuda::wmma::fragment half16x16_bT; -typedef nvcuda::wmma::fragment half16x16_acc; -#endif - -// based on metal version -template // D head size, Q queries per block, C cache items per block -static __global__ void flash_attn_ext_f16( - const char * __restrict__ q, - const char * __restrict__ k, - const char * __restrict__ v, +template // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_vec_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, const char * __restrict__ mask, float * __restrict__ dst, - float scale, - int ne00, - int ne01, - int ne02, - int ne03, - int ne10, - int ne11, - int ne12, - int ne13, - int ne31, - int nb31, - int nb01, - int nb02, - int nb03, - int nb11, - int nb12, - int nb13, - int ne0, - int ne1, - int ne2, - int ne3) { -#if __CUDA_ARCH__ >= CC_VOLTA - const int warp_id = threadIdx.y; - const int lane_id = threadIdx.x; - - const int num_warps = blockDim.y; // number of warps - const int iq3 = blockIdx.z; - const int iq2 = blockIdx.y; - const int iq1 = blockIdx.x * Q; - - const int D16 = D/16; - const int Q16 = Q/16; - const int C16 = C/16; - - const int NW = WARP_SIZE; - const int SH = (C + Q); // shared memory per simdgroup in (half) - - const int T = D + num_warps*SH; // shared memory size per query in (half) - const int T2 = T/2; // shared memory size per query in (half2) - const int C2 = C/2; - const int D2 = D/2; - - extern __shared__ half __flash_attn_f16_shmem[]; - // pq - half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data - half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 - half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix - half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2 - - half16x16_acc zr; - half16x16_acc lo[Q16][D16]; - - // load heads from Q to shared memory + const float scale, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*blockIdx.x); + const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) mask + ne31*blockIdx.x; + + const int stride_KV = nb11 / sizeof(half); + const int stride_KV2 = nb11 / sizeof(half2); + + constexpr int nwarps = D/WARP_SIZE; + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < D); + + __shared__ half KQ[D]; + KQ[tid] = 0.0f; + half2 * KQ2 = (half2 *) KQ; + + half kqmax = -INFINITY; + half kqsum = 0.0f; + + __shared__ half kqmax_shared[WARP_SIZE]; + __shared__ half kqsum_shared[WARP_SIZE]; + if (threadIdx.y == 0) { + kqmax_shared[threadIdx.x] = -INFINITY; + kqsum_shared[threadIdx.x] = 0.0f; + } + + __syncthreads(); + + // Convert Q to half2 and store in registers: + half2 Q_h2[(D/2 + WARP_SIZE - 1) / WARP_SIZE]; #pragma unroll - for (int j0 = 0; j0 < Q; j0 += num_warps) { - const int j = j0 + warp_id; - if (j >= Q) { + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { break; } - const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + Q_h2[i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(Q_f2[i].x, Q_f2[i].y); + } + + half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value. + for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += D) { + // Calculate KQ tile and keep track of new maximum KQ values: + half kqmax_new = kqmax; #pragma unroll - for (int i0 = 0; i0 < D2; i0 += NW) { - const int i = i0 + lane_id; - if (i >= D2) { + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + const int i_KQ = i_KQ_0 + threadIdx.y; + + if (256 % D != 0 && k_VKQ_0 + i_KQ >= ne11) { break; } - if (iq1 + j < ne01) { - sq2[j*T2 + i] = __float22half2_rn(q2[i]); - } else { - sq2[j*T2 + i] = make_half2(0.0, 0.0); + half2 sum2 = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + if (k_KQ_0 + WARP_SIZE > D/2 && k_KQ >= D/2) { + break; + } + + const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; + sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE]; } - } - } - nvcuda::wmma::fill_fragment(zr, 0.0); + sum2 = warp_reduce_sum(sum2); + half sum = __low2half(sum2) + __high2half(sum2); + sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f); + kqmax_new = __hmax(kqmax_new, sum); + if (threadIdx.x == 0) { + KQ[i_KQ] = sum; + } + } - // zero out lo - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::fill_fragment(lo[j][i], 0.0); + kqmax_new = warp_reduce_max(kqmax_new); + if (threadIdx.x == 0) { + kqmax_shared[threadIdx.y] = kqmax_new; } - } + __syncthreads(); + kqmax_new = kqmax_shared[threadIdx.x]; + kqmax_new = warp_reduce_max(kqmax_new); + + const half KQ_max_scale = hexp(kqmax - kqmax_new); + kqmax = kqmax_new; + + const half val = hexp(KQ[tid] - kqmax); + kqsum = kqsum*KQ_max_scale + val; + KQ[tid] = val; + + VKQ *= __half2half2(KQ_max_scale); + + __syncthreads(); - // zero out shared memory SH - for (int j = 0; j < Q; ++j) { - for (int i0 = 0; i0 < SH; i0 += NW) { - const int i = i0 + lane_id; - if (i >= SH) { +#pragma unroll + for (int k0 = 0; k0 < D; k0 += 2) { + if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { break; } - ss[j*T + i] = 0.0; + half2 V_k; + reinterpret_cast(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; + reinterpret_cast(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; + VKQ += V_k*KQ2[k0/2]; } } + kqsum = warp_reduce_sum(kqsum); + if (threadIdx.x == 0) { + kqsum_shared[threadIdx.y] = kqsum; + } __syncthreads(); + kqsum = kqsum_shared[threadIdx.x]; + kqsum = warp_reduce_sum(kqsum); - { - half S = __float2half(0.0f); - half M[Q]; + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum; +} - for (int i = 0; i < Q; ++i) { - M[i] = CUDART_MIN_DENORM_FP16; +template // D == head size +__launch_bounds__(ncols == 8 || D > 128 ? D : 2*D, 1) +static __global__ void flash_attn_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + const float scale, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); + constexpr int frag_m = ncols == 8 ? 32 : 16; + constexpr int frag_n = ncols == 8 ? 8 : 16; + static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); + typedef nvcuda::wmma::fragment frag_a_K; + typedef nvcuda::wmma::fragment frag_a_V; + typedef nvcuda::wmma::fragment frag_b; + typedef nvcuda::wmma::fragment frag_c; + + constexpr int nwarps = (D <= 128 || ncols == 8 ? D : D/2) / frag_m; + constexpr int nthreads = nwarps*WARP_SIZE; + static_assert(nthreads % D == 0, "nthreads not divisible by D."); + constexpr int tc_vals_per_iter = nwarps*frag_m; + static_assert(D % tc_vals_per_iter == 0, "D not divisible by tensor core vals per iter."); + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < nthreads); + constexpr int D_padded = D + 8; // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts. + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb02* blockIdx.y + ncols*nb01*blockIdx.x); + const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half2 * mask2 = (half2 *) mask + ncols*ne11*blockIdx.x/2; + + const int stride_Q = nb01 / sizeof(float); + const int stride_KV = nb11 / sizeof(half); + + frag_b Q_b[D/16][ncols/frag_n]; + + __shared__ half KQ[ncols*D_padded]; // Buffer for temporarily holding tiles of KQ. + half2 * KQ2 = (half2 *) KQ; + + half2 KQ_rowsum[(ncols + nwarps - 1) / nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max[(ncols + nwarps - 1) / nwarps] = {{-INFINITY, -INFINITY}}; + half2 KQ_max_scale[(ncols + nwarps - 1) / nwarps] = {{0.0f, 0.0f}}; + + __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. + half2 * VKQ2 = (half2 *) VKQ; +#pragma unroll + for (int i0 = 0; i0 < ncols*D_padded/2; i0 += nthreads) { + const int i = i0 + tid; + if (i0 + nthreads > ncols*D_padded/2 && i >= ncols*D_padded/2) { + break; } - // assume K and V are same shape - const int ne22 = ne12; - const int ne23 = ne13; - - const int nb21 = nb11; - const int nb22 = nb12; - const int nb23 = nb13; - - // broadcast - const int rk2 = ne02/ne12; - const int rk3 = ne03/ne13; - - const int rv2 = ne02/ne22; - const int rv3 = ne03/ne23; + VKQ2[i] = make_half2(0.0f, 0.0f); + } - // k indices - const int ik2 = iq2 / rk2; - const int ik3 = iq3 / rk3; + // Convert Q to half and apply scale, temporarily store in KQ: +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nthreads/D) { + const int j = j0 + tid/D; + const int i = tid % D; + KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } - // v indices - const int iv2 = iq2 / rv2; - const int iv3 = iq3 / rv3; + __syncthreads(); - // load the queries from shared memory into local memory - half16x16_a mq[Q16][D16]; - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); - } + // Load Q into tensor core fragments/registers since it will be used frequently: +#pragma unroll + for (int i0 = 0; i0 < D; i0 += 16) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); } + } - // pointer to the mask - const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr; + __syncthreads(); - // prepare diagonal scale matrix - half16x16_b mscale; - for (int i = 0; i < 16; ++i) { - ss[i*T + i] = __float2half(scale); - } - nvcuda::wmma::load_matrix_sync(mscale, ss, T); + // Iterate over ne11 == previous tokens: + for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += D) { + const bool has_valid_data = 256 % D == 0 || k_VKQ_0 + frag_m*threadIdx.y < ne11; - // loop over the KV cache - // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) { - const int ic = ic0 + warp_id*C; - if (ic >= ne11) { - break; + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) { + frag_c KQ_c[ncols/frag_n]; +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); } - - // Q*K^T - { + if (has_valid_data) { #pragma unroll - for (int cc = 0; cc < C16; ++cc) { - half16x16_acc mqk[Q16]; - for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::fill_fragment(mqk[j], 0); - } - - const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); - - for (int i = 0; i < D16; ++i) { - half16x16_bT mk; // transposed key - nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); - - for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); - } - } - - // mqk = mqk*scale + mask - for (int j = 0; j < Q16; ++j) { - half16x16_a mqka; - half16x16_acc mm; - - if (mp) { - nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); - } - - // convert accumulator to matrix_a - nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); - nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T); - - nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr); - nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { + frag_a_K K_a; + nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); } } } +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync(KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], D_padded, nvcuda::wmma::mem_col_major); + } + } - // used to detect blocks full of -INF - half2 smax = make_half2(-INFINITY, -INFINITY); - - // online softmax - for (int j = 0; j < Q; ++j) { - const half m = M[j]; - - for (int p0 = 0; p0 < C2; p0 += NW) { - const int p = p0 + lane_id; + __syncthreads(); - const half2 s = ss2[j*T2 + p]; + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + if (j0 + nwarps > ncols && j >= ncols) { + break; + } - smax = __hmax2(smax, s); - M[j] = __hmax(M[j], __hmax(s.x, s.y)); + half2 KQ_max_new = KQ_max[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + if (k0 + WARP_SIZE > D/2 && k >= D/2) { + break; } + KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(D_padded/2) + k]); + } + KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + KQ_max_scale[j0/nwarps] = h2exp(KQ_max[j0/nwarps] - KQ_max_new); + KQ_max[j0/nwarps] = KQ_max_new; - M[j] = warp_reduce_max(M[j]); - - // local sum - half2 ls = make_half2(0.0f, 0.0f); - half2 M2 = make_half2(M[j], M[j]); - - for (int p0 = 0; p0 < C2; p0 += NW) { - const int p = p0 + lane_id; - - const half2 s = ss2[j*T2 + p]; - - const half2 vs = h2exp(s - M2); - - ls += vs; - - // the P matrix from the paper (Q rows, C columns) - ss2[j*T2 + p] = vs; + half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + if (k0 + WARP_SIZE > D/2 && k >= D/2) { + break; } - - ls = warp_reduce_sum(ls); - - const half ms = hexp(m - M[j]); - - // create a QxQ diagonal matrix for rescaling the output - if (lane_id == j) { - ss[j*T + C + j] = ms; - - S = S*ms + ls.x + ls.y; + if (256 % D != 0 && k_VKQ_0 + 2*k >= ne11) { + break; } - } - smax = warp_reduce_max(smax); - - // skip -INF blocks - if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) { - continue; + half2 val = KQ2[j*(D_padded/2) + k]; + val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + val = h2exp(val - KQ_max[j0/nwarps]); + KQ_rowsum_add += val; + KQ2[j*(D_padded/2) + k] = val; } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); - // O = diag(ms)*O - for (int j = 0; j < Q16; ++j) { - half16x16_a mm; - half16x16_b lob; - - nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum[j0/nwarps] = KQ_max_scale[j0/nwarps]*KQ_rowsum[j0/nwarps] + KQ_rowsum_add; + } - for (int i = 0; i < D16; ++i) { - // convert accumulator to matrix_b - nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); - nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); + __syncthreads(); - nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); - } + frag_b KQ_b[D/16][ncols/frag_n]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { +#pragma unroll + for (int k0 = 0; k0 < D; k0 += 16) { + nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*D_padded + k0, D_padded); } + } - // restore zeros - for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); + frag_c VKQ_c[D/tc_vals_per_iter][ncols/frag_n]; +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) { + #pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(VKQ_c[i_KQ_0/tc_vals_per_iter][j], 0.0f); } - // O = O + (Q*K^T)*V - { - for (int cc = 0; cc < C16; ++cc) { - const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); - - half16x16_b mv[D16]; - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half)); - } - - half16x16_a ms[Q16]; - for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T); - } + #pragma unroll + for (int k0 = 0; k0 < D; k0 += 16) { + if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { + break; + } - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]); - } - } + frag_a_V v_a; + nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k0)*stride_KV + i_KQ_0 + frag_m*threadIdx.y, stride_KV); + #pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(VKQ_c[i_KQ_0/tc_vals_per_iter][j], v_a, KQ_b[k0/16][j], VKQ_c[i_KQ_0/tc_vals_per_iter][j]); } } } - // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - if (lane_id < Q) { - ss[lane_id*T + 0] = S; - ss[lane_id*T + 1] = M[lane_id]; - } - } - - // reduce the warps sequentially - for (int sg = 1; sg < num_warps; ++sg) { __syncthreads(); - // each simdgroup stores its output to shared memory, reusing sq - if (warp_id == sg) { - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); - } +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync( + KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y, + VKQ_c[i_KQ_0/tc_vals_per_iter][j0/frag_n], + D_padded, nvcuda::wmma::mem_col_major); } } __syncthreads(); - // the first simdgroup accumulates the results from the other simdgroups - if (warp_id == 0) { - for (int j = lane_id; j < Q; j += NW) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*SH + 0]; - - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*SH + 1]; - - const half M = __hmax(M0, M1); - - const half ms0 = hexp(M0 - M); - const half ms1 = hexp(M1 - M); - - const half S = S0*ms0 + S1*ms1; - - ss[j*T + 0] = S; - ss[j*T + 1] = M; - - ss[j*T + C + j ] = ms0; - ss[j*T + C + j + sg*SH] = ms1; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + if (j0 + nwarps > ncols && j >= ncols) { + break; } - - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (int j = 0; j < Q16; ++j) { - half16x16_a ms0; - half16x16_a ms1; - half16x16_b t; - half16x16_acc t2; - - nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); - nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); - - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); - nvcuda::wmma::mma_sync(t2, ms1, t, zr); - - // convert accumulator to matrix_b - nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); - nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T); - - nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; } + VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + KQ2[j*(D_padded/2) + i]; } } - } - // store result to shared memory (reuse sq) - if (warp_id == 0) { - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); - } - } + __syncthreads(); } - // final rescale with 1/S and store to global memory - if (warp_id == 0) { - for (int j = 0; j < Q && iq1 + j < ne01; ++j) { - const half S = ss[j*T + 0]; - - for (int i0 = 0; i0 < D; i0 += NW) { - const int i = i0 + lane_id; - if (i >= D) { - break; - } - - dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + if ((j0 + nwarps > ncols && j >= ncols) || ncols*blockIdx.x + j >= ne01) { + return; + } + const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; } + dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; } } -#else - NO_DEVICE_CODE; -#endif } + void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -461,133 +463,586 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst float scale; memcpy(&scale, KQV->op_params, sizeof(float)); -#define NQPB 16 -#define NCPW 128 - - const int nqpb = NQPB; // queries per block - const int ncpw = NCPW; // cache values per warp (does not work for other values) - - GGML_ASSERT(NQPB <= 32); - - const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? - // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why - const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1; - - dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); - dim3 block_dim(32, nwarps, 1); - - const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); + if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[1] == 1) { + const int nwarps = Q->ne[0] / WARP_SIZE; + const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]); + const dim3 block_dim(WARP_SIZE, nwarps, 1); + const int shmem = 0; + switch (Q->ne[0]) { + case 64: + flash_attn_vec_ext_f16<64> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + // case 80: + // flash_attn_vec_ext_f16<80> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; + case 96: + flash_attn_vec_ext_f16<96> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + // case 112: + // flash_attn_vec_ext_f16<112> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; + case 128: + flash_attn_vec_ext_f16<128> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 256: + flash_attn_vec_ext_f16<256> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + GGML_ASSERT(false); + break; + } + CUDA_CHECK(cudaGetLastError()); + return; + } - // increase shared memory limit to 96KB - //const size_t shmem_max = 96*1024; - //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max); + int cols_per_block; + if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { + cols_per_block = 64; + } else if (Q->ne[1] >= 64) { + cols_per_block = 32; + } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { + cols_per_block = 16; + } else { + cols_per_block = 8; + } + const int frag_m = cols_per_block == 8 ? 32 : 16; + const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m; + const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); + const dim3 block_dim(WARP_SIZE, nwarps, 1); + const size_t shmem = 0; switch (Q->ne[0]) { - case 64: - flash_attn_ext_f16<64, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 80: - flash_attn_ext_f16<80, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 96: - flash_attn_ext_f16<96, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 112: - flash_attn_ext_f16<112, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 128: - flash_attn_ext_f16<128, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 256: - flash_attn_ext_f16<256, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + case 64: switch (cols_per_block) { + case 8: + flash_attn_ext_f16<64, 8> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 16: + flash_attn_ext_f16<64, 16> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 32: + flash_attn_ext_f16<64, 32> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<64, 64> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 80: switch (cols_per_block) { + // case 8: + // fused_attn_vec_ext_f16<80, 8> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; + case 16: + flash_attn_ext_f16<80, 16> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 32: + flash_attn_ext_f16<80, 32> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<80, 64> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 96: switch (cols_per_block) { + case 8: + flash_attn_ext_f16<96, 8> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 16: + flash_attn_ext_f16<96, 16> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 32: + flash_attn_ext_f16<96, 32> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<96, 64> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 112: switch (cols_per_block) { + // case 8: + // fused_attn_vec_ext_f16<112, 8> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; + case 16: + flash_attn_ext_f16<112, 16> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 32: + flash_attn_ext_f16<112, 32> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<112, 64> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 128: switch (cols_per_block) { + case 8: + flash_attn_ext_f16<128, 8> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 16: + flash_attn_ext_f16<128, 16> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 32: + flash_attn_ext_f16<128, 32> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<128, 64> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 256: switch (cols_per_block) { + case 8: + flash_attn_ext_f16<256, 8> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 16: + flash_attn_ext_f16<256, 16> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 32: + flash_attn_ext_f16<256, 32> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + // case 64: + // flash_attn_ext_f16<256, 64> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; default: + GGML_ASSERT(false); break; } - CUDA_CHECK(cudaGetLastError()); } From d59ac670bf92f18ba9db44f37fef93b002528ac8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 30 Mar 2024 09:19:19 +0100 Subject: [PATCH 110/180] 16 cols for Phi-2 --- ggml-cuda/fattn.cu | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index ccb3c924609a4..d34924c3173e6 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -579,15 +579,17 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - int cols_per_block; - if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { - cols_per_block = 64; - } else if (Q->ne[1] >= 64) { - cols_per_block = 32; - } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { - cols_per_block = 16; - } else { - cols_per_block = 8; + int cols_per_block = 16; + if (Q->ne[0] % 32 == 0) { + if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { + cols_per_block = 64; + } else if (Q->ne[1] >= 64) { + cols_per_block = 32; + } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { + cols_per_block = 16; + } else { + cols_per_block = 8; + } } const int frag_m = cols_per_block == 8 ? 32 : 16; const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m; From 81da919864831948f292aeb0a5bd11eb5868bdb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 30 Mar 2024 10:34:09 +0100 Subject: [PATCH 111/180] no vec for hs, no hs==256 ncols==32 for Volta --- ggml-cuda/common.cuh | 1 + ggml-cuda/fattn.cu | 72 ++++++++++++++++++++++---------------------- 2 files changed, 37 insertions(+), 36 deletions(-) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 33c8ed1da8d83..c245dd6ac009a 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -141,6 +141,7 @@ #define CC_PASCAL 600 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products #define CC_VOLTA 700 +#define CC_AMPERE 800 #define CC_OFFSET_AMD 1000000 #define CC_RDNA1 (CC_OFFSET_AMD + 1010) #define CC_RDNA2 (CC_OFFSET_AMD + 1030) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index d34924c3173e6..43b9a9f4a3d11 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -463,29 +463,29 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst float scale; memcpy(&scale, KQV->op_params, sizeof(float)); - if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[1] == 1) { + if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[0] >= 128 && Q->ne[1] == 1) { const int nwarps = Q->ne[0] / WARP_SIZE; const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]); const dim3 block_dim(WARP_SIZE, nwarps, 1); const int shmem = 0; switch (Q->ne[0]) { - case 64: - flash_attn_vec_ext_f16<64> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + // case 64: + // flash_attn_vec_ext_f16<64> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; // case 80: // flash_attn_vec_ext_f16<80> // <<>> ( @@ -503,23 +503,23 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] // ); // break; - case 96: - flash_attn_vec_ext_f16<96> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + // case 96: + // flash_attn_vec_ext_f16<96> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; // case 112: // flash_attn_vec_ext_f16<112> // <<>> ( @@ -583,7 +583,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst if (Q->ne[0] % 32 == 0) { if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { cols_per_block = 64; - } else if (Q->ne[1] >= 64) { + } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { cols_per_block = 32; } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { cols_per_block = 16; From 269374ed818dde2b267307d62be7ba59385aebfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 31 Mar 2024 16:01:27 +0200 Subject: [PATCH 112/180] adjust kernel selection logic --- ggml-cuda/fattn.cu | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 43b9a9f4a3d11..f2c46008633d9 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -579,17 +579,15 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - int cols_per_block = 16; - if (Q->ne[0] % 32 == 0) { - if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { - cols_per_block = 64; - } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { - cols_per_block = 32; - } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { - cols_per_block = 16; - } else { - cols_per_block = 8; - } + int cols_per_block; + if (Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) { + cols_per_block = 64; + } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { + cols_per_block = 32; + } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { + cols_per_block = 16; + } else { + cols_per_block = 8; } const int frag_m = cols_per_block == 8 ? 32 : 16; const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m; From cca6d027a323b071d951f702ab3ede0d1937bb6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 31 Mar 2024 18:39:02 +0200 Subject: [PATCH 113/180] 4 warps, 256 stride for all D --- ggml-cuda/fattn.cu | 633 +++++++++++---------------------------------- 1 file changed, 147 insertions(+), 486 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index f2c46008633d9..aa85244fc52ce 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -1,3 +1,4 @@ +#include "common.cuh" #include "fattn.cuh" #include @@ -176,8 +177,10 @@ static __global__ void flash_attn_vec_ext_f16( dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum; } -template // D == head size -__launch_bounds__(ncols == 8 || D > 128 ? D : 2*D, 1) +#define FATTN_KQ_STRIDE 256 + +template // D == head size, VKQ_stride == num VKQ rows calculated in parallel +__launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -206,6 +209,7 @@ static __global__ void flash_attn_ext_f16( const int ne2, const int ne3) { //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 8 : 16; @@ -215,14 +219,13 @@ static __global__ void flash_attn_ext_f16( typedef nvcuda::wmma::fragment frag_b; typedef nvcuda::wmma::fragment frag_c; - constexpr int nwarps = (D <= 128 || ncols == 8 ? D : D/2) / frag_m; - constexpr int nthreads = nwarps*WARP_SIZE; - static_assert(nthreads % D == 0, "nthreads not divisible by D."); - constexpr int tc_vals_per_iter = nwarps*frag_m; - static_assert(D % tc_vals_per_iter == 0, "D not divisible by tensor core vals per iter."); - const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < nthreads); - constexpr int D_padded = D + 8; // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts. + constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. + constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. + static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); + + // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: + constexpr int D_padded = D + 8; + constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const float * Q_f = (const float *) (Q + nb02* blockIdx.y + ncols*nb01*blockIdx.x); @@ -235,31 +238,43 @@ static __global__ void flash_attn_ext_f16( frag_b Q_b[D/16][ncols/frag_n]; - __shared__ half KQ[ncols*D_padded]; // Buffer for temporarily holding tiles of KQ. + // A single buffer for temporarily holding tiles of KQ and VKQ parts: + constexpr int mem_KQ = ncols*kqs_padded; + constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; + __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; half2 * KQ2 = (half2 *) KQ; - half2 KQ_rowsum[(ncols + nwarps - 1) / nwarps] = {{0.0f, 0.0f}}; - half2 KQ_max[(ncols + nwarps - 1) / nwarps] = {{-INFINITY, -INFINITY}}; - half2 KQ_max_scale[(ncols + nwarps - 1) / nwarps] = {{0.0f, 0.0f}}; + half2 KQ_rowsum[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max[ncols/nwarps] = {{-INFINITY, -INFINITY}}; + half2 KQ_max_scale[ncols/nwarps] = {{0.0f, 0.0f}}; __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; #pragma unroll - for (int i0 = 0; i0 < ncols*D_padded/2; i0 += nthreads) { - const int i = i0 + tid; - if (i0 + nthreads > ncols*D_padded/2 && i >= ncols*D_padded/2) { - break; + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); } - - VKQ2[i] = make_half2(0.0f, 0.0f); } // Convert Q to half and apply scale, temporarily store in KQ: #pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nthreads/D) { - const int j = j0 + tid/D; - const int i = tid % D; - KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } } __syncthreads(); @@ -276,31 +291,27 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); // Iterate over ne11 == previous tokens: - for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += D) { - const bool has_valid_data = 256 % D == 0 || k_VKQ_0 + frag_m*threadIdx.y < ne11; - + for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += FATTN_KQ_STRIDE) { // Calculate tile of KQ: #pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) { + for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { frag_c KQ_c[ncols/frag_n]; #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); } - if (has_valid_data) { #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { - frag_a_K K_a; - nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { + frag_a_K K_a; + nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); #pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); - } + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); } } #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync(KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], D_padded, nvcuda::wmma::mem_col_major); + nvcuda::wmma::store_matrix_sync(KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); } } @@ -311,18 +322,12 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; - if (j0 + nwarps > ncols && j >= ncols) { - break; - } half2 KQ_max_new = KQ_max[j0/nwarps]; #pragma unroll - for (int k0 = 0; k0 < D/2; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - if (k0 + WARP_SIZE > D/2 && k >= D/2) { - break; - } - KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(D_padded/2) + k]); + KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(kqs_padded/2) + k]); } KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); KQ_max_scale[j0/nwarps] = h2exp(KQ_max[j0/nwarps] - KQ_max_new); @@ -330,20 +335,14 @@ static __global__ void flash_attn_ext_f16( half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); #pragma unroll - for (int k0 = 0; k0 < D/2; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - if (k0 + WARP_SIZE > D/2 && k >= D/2) { - break; - } - if (256 % D != 0 && k_VKQ_0 + 2*k >= ne11) { - break; - } - half2 val = KQ2[j*(D_padded/2) + k]; + half2 val = KQ2[j*(kqs_padded/2) + k]; val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); val = h2exp(val - KQ_max[j0/nwarps]); KQ_rowsum_add += val; - KQ2[j*(D_padded/2) + k] = val; + KQ2[j*(kqs_padded/2) + k] = val; } KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); @@ -353,47 +352,46 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); - frag_b KQ_b[D/16][ncols/frag_n]; + frag_b KQ_b[FATTN_KQ_STRIDE/16][ncols/frag_n]; #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { #pragma unroll - for (int k0 = 0; k0 < D; k0 += 16) { - nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*D_padded + k0, D_padded); + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += 16) { + nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*kqs_padded + k0, kqs_padded); } } - frag_c VKQ_c[D/tc_vals_per_iter][ncols/frag_n]; + frag_c VKQ_c[D/VKQ_stride][ncols/frag_n]; +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { #pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) { - #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::fill_fragment(VKQ_c[i_KQ_0/tc_vals_per_iter][j], 0.0f); + nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); } - #pragma unroll - for (int k0 = 0; k0 < D; k0 += 16) { - if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { - break; - } +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; frag_a_V v_a; - nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k0)*stride_KV + i_KQ_0 + frag_m*threadIdx.y, stride_KV); - #pragma unroll + nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); +#pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(VKQ_c[i_KQ_0/tc_vals_per_iter][j], v_a, KQ_b[k0/16][j], VKQ_c[i_KQ_0/tc_vals_per_iter][j]); + nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k/16][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); } } } __syncthreads(); + const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); #pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) { + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { nvcuda::wmma::store_matrix_sync( - KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y, - VKQ_c[i_KQ_0/tc_vals_per_iter][j0/frag_n], + KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], D_padded, nvcuda::wmma::mem_col_major); } } @@ -403,16 +401,19 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; - if (j0 + nwarps > ncols && j >= ncols) { - break; - } #pragma unroll for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; if (i0 + WARP_SIZE > D/2 && i >= D/2) { break; } - VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + KQ2[j*(D_padded/2) + i]; + + half2 VKQ_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int l = 0; l < VKQ_ratio; ++l) { + VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; + } + VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + VKQ_add; } } @@ -422,7 +423,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; - if ((j0 + nwarps > ncols && j >= ncols) || ncols*blockIdx.x + j >= ne01) { + if (ncols*blockIdx.x + j >= ne01) { return; } const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); @@ -437,6 +438,50 @@ static __global__ void flash_attn_ext_f16( } } +constexpr int get_max_power_of_2(int x) { + return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; +} + +static_assert(get_max_power_of_2(1) == 1, "Test failed."); +static_assert(get_max_power_of_2(2) == 2, "Test failed."); +static_assert(get_max_power_of_2(4) == 4, "Test failed."); +static_assert(get_max_power_of_2(6) == 2, "Test failed."); + +// Number of VKQ rows calculated in parallel: +constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) { + return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m; +} + +static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed."); +static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); + +#define FATTN_SWITCH_CASE(D, ncols, nwarps) \ + case ncols: { \ + constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \ + flash_attn_ext_f16 \ + <<>> ( \ + (const char *) Q->data, \ + (const char *) K->data, \ + (const char *) V->data, \ + mask ? ((const char *) mask->data) : nullptr, \ + (float *) KQV->data, \ + scale, \ + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \ + K->ne[0], K->ne[1], K->ne[2], K->ne[3], \ + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \ + Q->nb[1], Q->nb[2], Q->nb[3], \ + K->nb[1], K->nb[2], K->nb[3], \ + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \ + ); \ + } \ + break; \ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; @@ -580,7 +625,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } int cols_per_block; - if (Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) { + if (false && Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) { cols_per_block = 64; } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { cols_per_block = 32; @@ -590,451 +635,67 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst cols_per_block = 8; } const int frag_m = cols_per_block == 8 ? 32 : 16; - const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m; + const int nwarps = 4; const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); const dim3 block_dim(WARP_SIZE, nwarps, 1); const size_t shmem = 0; switch (Q->ne[0]) { case 64: switch (cols_per_block) { - case 8: - flash_attn_ext_f16<64, 8> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 16: - flash_attn_ext_f16<64, 16> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 32: - flash_attn_ext_f16<64, 32> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<64, 64> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + FATTN_SWITCH_CASE(64, 8, nwarps); + FATTN_SWITCH_CASE(64, 16, nwarps); + FATTN_SWITCH_CASE(64, 32, nwarps); + FATTN_SWITCH_CASE(64, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; } break; case 80: switch (cols_per_block) { - // case 8: - // fused_attn_vec_ext_f16<80, 8> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; - case 16: - flash_attn_ext_f16<80, 16> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 32: - flash_attn_ext_f16<80, 32> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<80, 64> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + // FATTN_SWITCH_CASE(80, 8, nwarps); + FATTN_SWITCH_CASE(80, 16, nwarps); + FATTN_SWITCH_CASE(80, 32, nwarps); + // FATTN_SWITCH_CASE(80, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; } break; case 96: switch (cols_per_block) { - case 8: - flash_attn_ext_f16<96, 8> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 16: - flash_attn_ext_f16<96, 16> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 32: - flash_attn_ext_f16<96, 32> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<96, 64> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + FATTN_SWITCH_CASE(96, 8, nwarps); + FATTN_SWITCH_CASE(96, 16, nwarps); + FATTN_SWITCH_CASE(96, 32, nwarps); + FATTN_SWITCH_CASE(96, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; } break; case 112: switch (cols_per_block) { - // case 8: - // fused_attn_vec_ext_f16<112, 8> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; - case 16: - flash_attn_ext_f16<112, 16> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 32: - flash_attn_ext_f16<112, 32> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<112, 64> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + // FATTN_SWITCH_CASE(112, 8, nwarps); + FATTN_SWITCH_CASE(112, 16, nwarps); + FATTN_SWITCH_CASE(112, 32, nwarps); + // FATTN_SWITCH_CASE(112, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; } break; case 128: switch (cols_per_block) { - case 8: - flash_attn_ext_f16<128, 8> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 16: - flash_attn_ext_f16<128, 16> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 32: - flash_attn_ext_f16<128, 32> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<128, 64> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + FATTN_SWITCH_CASE(128, 8, nwarps); + FATTN_SWITCH_CASE(128, 16, nwarps); + FATTN_SWITCH_CASE(128, 32, nwarps); + // FATTN_SWITCH_CASE(128, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; } break; case 256: switch (cols_per_block) { - case 8: - flash_attn_ext_f16<256, 8> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 16: - flash_attn_ext_f16<256, 16> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 32: - flash_attn_ext_f16<256, 32> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - // case 64: - // flash_attn_ext_f16<256, 64> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; + FATTN_SWITCH_CASE(256, 8, nwarps); + FATTN_SWITCH_CASE(256, 16, nwarps); + FATTN_SWITCH_CASE(256, 32, nwarps); + // FATTN_SWITCH_CASE(256, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); From 68d793bee816e44876b22232891ee7bab51ee5e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 1 Apr 2024 15:54:50 +0200 Subject: [PATCH 114/180] no ncols == 64 --- ggml-cuda/fattn.cu | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index aa85244fc52ce..19108044e762f 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -625,9 +625,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } int cols_per_block; - if (false && Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) { - cols_per_block = 64; - } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { + if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { cols_per_block = 32; } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { cols_per_block = 16; @@ -645,7 +643,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst FATTN_SWITCH_CASE(64, 8, nwarps); FATTN_SWITCH_CASE(64, 16, nwarps); FATTN_SWITCH_CASE(64, 32, nwarps); - FATTN_SWITCH_CASE(64, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); @@ -655,7 +652,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst // FATTN_SWITCH_CASE(80, 8, nwarps); FATTN_SWITCH_CASE(80, 16, nwarps); FATTN_SWITCH_CASE(80, 32, nwarps); - // FATTN_SWITCH_CASE(80, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); @@ -665,7 +661,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst FATTN_SWITCH_CASE(96, 8, nwarps); FATTN_SWITCH_CASE(96, 16, nwarps); FATTN_SWITCH_CASE(96, 32, nwarps); - FATTN_SWITCH_CASE(96, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); @@ -675,7 +670,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst // FATTN_SWITCH_CASE(112, 8, nwarps); FATTN_SWITCH_CASE(112, 16, nwarps); FATTN_SWITCH_CASE(112, 32, nwarps); - // FATTN_SWITCH_CASE(112, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); @@ -685,7 +679,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst FATTN_SWITCH_CASE(128, 8, nwarps); FATTN_SWITCH_CASE(128, 16, nwarps); FATTN_SWITCH_CASE(128, 32, nwarps); - // FATTN_SWITCH_CASE(128, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); @@ -695,7 +688,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst FATTN_SWITCH_CASE(256, 8, nwarps); FATTN_SWITCH_CASE(256, 16, nwarps); FATTN_SWITCH_CASE(256, 32, nwarps); - // FATTN_SWITCH_CASE(256, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); From 3f777acf06a0c21780e2e6b1d5b0b8e9a2fbd922 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 1 Apr 2024 16:41:56 +0200 Subject: [PATCH 115/180] Multiple parallel blocks for batch size 1 --- ggml-cuda/fattn.cu | 417 ++++++++++++++++++++++++++++++--------------- 1 file changed, 283 insertions(+), 134 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 19108044e762f..4b51f1b747cf3 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -29,14 +29,17 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX } -template // D == head size -__launch_bounds__(D, 1) +#define FATTN_KQ_STRIDE 256 + +template // D == head size +__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1) static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, float * __restrict__ dst, + half2 * __restrict__ dst_meta, const float scale, const int ne00, const int ne01, @@ -60,20 +63,25 @@ static __global__ void flash_attn_vec_ext_f16( const int ne3) { //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*blockIdx.x); + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + ne31*blockIdx.x; + const half * maskh = (const half *) mask; + + if (parallel_blocks == 1) { + Q_f2 += blockIdx.x*nb01/sizeof(float2); + maskh += blockIdx.x*ne11; + } const int stride_KV = nb11 / sizeof(half); const int stride_KV2 = nb11 / sizeof(half2); - constexpr int nwarps = D/WARP_SIZE; + constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < D); + __builtin_assume(tid < nwarps*WARP_SIZE); - __shared__ half KQ[D]; - KQ[tid] = 0.0f; + __shared__ half KQ[nwarps*WARP_SIZE]; + KQ[tid] = -INFINITY; half2 * KQ2 = (half2 *) KQ; half kqmax = -INFINITY; @@ -85,7 +93,6 @@ static __global__ void flash_attn_vec_ext_f16( kqmax_shared[threadIdx.x] = -INFINITY; kqsum_shared[threadIdx.x] = 0.0f; } - __syncthreads(); // Convert Q to half2 and store in registers: @@ -102,14 +109,15 @@ static __global__ void flash_attn_vec_ext_f16( half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value. - for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += D) { + const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*D; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { // Calculate KQ tile and keep track of new maximum KQ values: half kqmax_new = kqmax; #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { const int i_KQ = i_KQ_0 + threadIdx.y; - if (256 % D != 0 && k_VKQ_0 + i_KQ >= ne11) { + if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { break; } @@ -153,19 +161,25 @@ static __global__ void flash_attn_vec_ext_f16( __syncthreads(); + if (tid < D) { #pragma unroll - for (int k0 = 0; k0 < D; k0 += 2) { - if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { - break; - } + for (int k0 = 0; k0 < D; k0 += 2) { + if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { + break; + } - half2 V_k; - reinterpret_cast(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; - reinterpret_cast(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; - VKQ += V_k*KQ2[k0/2]; + half2 V_k; + reinterpret_cast(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; + reinterpret_cast(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; + VKQ += V_k*KQ2[k0/2]; + } } } + if (tid >= D) { + kqsum = 0.0f; + } + kqsum = warp_reduce_sum(kqsum); if (threadIdx.x == 0) { kqsum_shared[threadIdx.y] = kqsum; @@ -174,12 +188,22 @@ static __global__ void flash_attn_vec_ext_f16( kqsum = kqsum_shared[threadIdx.x]; kqsum = warp_reduce_sum(kqsum); - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum; -} + if (tid >= D) { + return; + } -#define FATTN_KQ_STRIDE 256 + if (parallel_blocks == 1) { + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum; + } else { + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)); -template // D == head size, VKQ_stride == num VKQ rows calculated in parallel + if (tid == 0) { + dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum); + } + } +} + +template // D == head size, VKQ_stride == num VKQ rows calculated in parallel __launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -187,6 +211,7 @@ static __global__ void flash_attn_ext_f16( const char * __restrict__ V, const char * __restrict__ mask, float * __restrict__ dst, + half2 * __restrict__ dst_meta, const float scale, const int ne00, const int ne01, @@ -228,10 +253,15 @@ static __global__ void flash_attn_ext_f16( constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb02* blockIdx.y + ncols*nb01*blockIdx.x); + const float * Q_f = (const float *) (Q + nb02* blockIdx.y); const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half2 * mask2 = (half2 *) mask + ncols*ne11*blockIdx.x/2; + const half2 * mask2 = (half2 *) mask; + + if (parallel_blocks == 1) { + Q_f += blockIdx.x * ncols*nb01/sizeof(float); + mask2 += blockIdx.x * ncols*ne11/2; + } const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); @@ -273,7 +303,11 @@ static __global__ void flash_attn_ext_f16( if (i0 + WARP_SIZE > D && i >= D) { break; } - KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + if (parallel_blocks == 1) { + KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } else { + KQ[j*D_padded + i] = j == 0 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } } } @@ -291,7 +325,8 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); // Iterate over ne11 == previous tokens: - for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += FATTN_KQ_STRIDE) { + const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*FATTN_KQ_STRIDE; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { // Calculate tile of KQ: #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { @@ -420,22 +455,75 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); } + if (parallel_blocks == 1) { #pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - if (ncols*blockIdx.x + j >= ne01) { - return; - } - const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + if (ncols*blockIdx.x + j >= ne01) { + return; + } + const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { - break; + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; } - dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; } + return; + } + +#pragma unroll + for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) { + const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x; + if (i0 + nwarps*WARP_SIZE > D && i >= D) { + return; + } + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i]; + } + + if (threadIdx.y == 0 && threadIdx.x == 0) { + dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2( + __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); + } +} + +template // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_combine_results( + const float * __restrict__ VKQ_parts, + const half2 * __restrict__ VKQ_meta, + float * __restrict__ dst) { + + const int tid = threadIdx.x; + __builtin_assume(tid < D); + + __shared__ half2 meta[parallel_blocks]; + if (tid < parallel_blocks) { + meta[threadIdx.x] = VKQ_meta[blockIdx.y*parallel_blocks + tid]; } + + __syncthreads(); + + half kqmax = __low2half(meta[0]); +#pragma unroll + for (int l = 1; l < parallel_blocks; ++l) { + kqmax = __hmax(kqmax, __low2half(meta[l])); + } + + float VKQ_numerator = 0.0f; + float VKQ_denominator = 0.0f; +#pragma unroll + for (int l = 0; l < parallel_blocks; ++l) { + float KQ_max_scale = hexp(__low2half(meta[l]) - kqmax); + + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; + VKQ_denominator += KQ_max_scale * __high2float(meta[l]); + } + + dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; } constexpr int get_max_power_of_2(int x) { @@ -462,26 +550,26 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); -#define FATTN_SWITCH_CASE(D, ncols, nwarps) \ - case ncols: { \ - constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \ - flash_attn_ext_f16 \ - <<>> ( \ - (const char *) Q->data, \ - (const char *) K->data, \ - (const char *) V->data, \ - mask ? ((const char *) mask->data) : nullptr, \ - (float *) KQV->data, \ - scale, \ - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \ - K->ne[0], K->ne[1], K->ne[2], K->ne[3], \ - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \ - Q->nb[1], Q->nb[2], Q->nb[3], \ - K->nb[1], K->nb[2], K->nb[3], \ - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \ - ); \ - } \ - break; \ +#define FATTN_SWITCH_CASE(D, ncols, nwarps) \ + case ncols: { \ + constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \ + flash_attn_ext_f16 \ + <<>> ( \ + (const char *) Q->data, \ + (const char *) K->data, \ + (const char *) V->data, \ + mask ? ((const char *) mask->data) : nullptr, \ + (float *) KQV->data, nullptr, \ + scale, \ + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \ + K->ne[0], K->ne[1], K->ne[2], K->ne[3], \ + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \ + Q->nb[1], Q->nb[2], Q->nb[3], \ + K->nb[1], K->nb[2], K->nb[3], \ + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \ + ); \ + } \ + break; \ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; @@ -508,88 +596,135 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst float scale; memcpy(&scale, KQV->op_params, sizeof(float)); - if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[0] >= 128 && Q->ne[1] == 1) { - const int nwarps = Q->ne[0] / WARP_SIZE; - const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]); + if (Q->ne[1] == 1) { + constexpr int parallel_blocks = 4; + + ggml_cuda_pool_alloc dst_tmp(ctx.pool()); + ggml_cuda_pool_alloc dst_tmp_meta(ctx.pool()); + + const int nwarps = (Q->ne[0] + WARP_SIZE - 1) / WARP_SIZE; + const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); const dim3 block_dim(WARP_SIZE, nwarps, 1); const int shmem = 0; + + // Performance of the vector kernel is very bad for head sizes 80 and 112, use the tensor core kernel instead: + constexpr int nwarps_tc = 4; + constexpr dim3 block_dim_tc(WARP_SIZE, nwarps_tc, 1); + + const dim3 blocks_num_combine(1, blocks_num.y, blocks_num.z); + const dim3 block_dim_combine(Q->ne[0], 1, 1); + const int shmem_combine = 0; + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + switch (Q->ne[0]) { - // case 64: - // flash_attn_vec_ext_f16<64> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; - // case 80: - // flash_attn_vec_ext_f16<80> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; - // case 96: - // flash_attn_vec_ext_f16<96> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; - // case 112: - // flash_attn_vec_ext_f16<112> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; + case 64: + flash_attn_vec_ext_f16<64, parallel_blocks> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<64, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + break; + case 80: + flash_attn_ext_f16<80, 16, nwarps_tc, get_VKQ_stride(80, nwarps_tc, 16), parallel_blocks> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<80, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + break; + case 96: + flash_attn_vec_ext_f16<96, parallel_blocks> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<96, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + break; + case 112: + flash_attn_vec_ext_f16<112, parallel_blocks> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<112, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + break; case 128: - flash_attn_vec_ext_f16<128> + flash_attn_vec_ext_f16<128, parallel_blocks> <<>> ( (const char *) Q->data, // Query (const char *) K->data, // Key (const char *) V->data, // Value mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, scale, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], @@ -598,15 +733,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst K->nb[1], K->nb[2], K->nb[3], KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<128, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); break; case 256: - flash_attn_vec_ext_f16<256> + flash_attn_vec_ext_f16<256, parallel_blocks> <<>> ( (const char *) Q->data, // Query (const char *) K->data, // Key (const char *) V->data, // Value mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, scale, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], @@ -615,6 +757,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst K->nb[1], K->nb[2], K->nb[3], KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<256, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); break; default: GGML_ASSERT(false); @@ -633,7 +782,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst cols_per_block = 8; } const int frag_m = cols_per_block == 8 ? 32 : 16; - const int nwarps = 4; + constexpr int nwarps = 4; const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); const dim3 block_dim(WARP_SIZE, nwarps, 1); const size_t shmem = 0; From e1ecd3b1290adf0086b7b5a45fe97d22afe6a963 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 2 Apr 2024 10:27:34 +0200 Subject: [PATCH 116/180] fix compile warnings --- ggml-cuda/fattn.cu | 48 ++++++++++++++++++++++------------------------ 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 4b51f1b747cf3..d1d018dc7a7e8 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -16,18 +16,18 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } -static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); - } - return x; -#else - GGML_UNUSED(x); - NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -} +// static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { +// #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +// #pragma unroll +// for (int mask = 16; mask > 0; mask >>= 1) { +// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); +// } +// return x; +// #else +// GGML_UNUSED(x); +// NO_DEVICE_CODE; +// #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +// } #define FATTN_KQ_STRIDE 256 @@ -472,21 +472,20 @@ static __global__ void flash_attn_ext_f16( dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; } } - return; - } - + } else { #pragma unroll - for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) { - const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x; - if (i0 + nwarps*WARP_SIZE > D && i >= D) { - return; + for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) { + const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x; + if (i0 + nwarps*WARP_SIZE > D && i >= D) { + return; + } + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i]; } - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i]; - } - if (threadIdx.y == 0 && threadIdx.x == 0) { - dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2( - __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); + if (threadIdx.y == 0 && threadIdx.x == 0) { + dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2( + __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); + } } } @@ -781,7 +780,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } else { cols_per_block = 8; } - const int frag_m = cols_per_block == 8 ? 32 : 16; constexpr int nwarps = 4; const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); const dim3 block_dim(WARP_SIZE, nwarps, 1); From bb0d51accd7e99c151088c11f1b6774a753dc05c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 2 Apr 2024 11:13:46 +0200 Subject: [PATCH 117/180] fix excessive KQ_b loads --- ggml-cuda/fattn.cu | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index d1d018dc7a7e8..dcd2129a2cfc9 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -387,12 +387,16 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); - frag_b KQ_b[FATTN_KQ_STRIDE/16][ncols/frag_n]; + frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += 16) { - nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*kqs_padded + k0, kqs_padded); + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + nvcuda::wmma::load_matrix_sync( + KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], + KQ + j0*kqs_padded + k, + kqs_padded); } } @@ -412,7 +416,7 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k/16][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); + nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); } } } From c63dfdf765c48a0f78e162d0c02c7d69cbbc3083 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 2 Apr 2024 11:58:59 +0200 Subject: [PATCH 118/180] fix cmake build --- ggml-cuda/common.cuh | 26 ++++++++++++-------------- ggml-cuda/fattn.cu | 38 ++++++++++++-------------------------- 2 files changed, 24 insertions(+), 40 deletions(-) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index c245dd6ac009a..510ca6281471e 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -271,7 +271,6 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { return a; } -#ifdef GGML_CUDA_F16 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL #pragma unroll @@ -284,7 +283,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { NO_DEVICE_CODE; #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } -#endif // GGML_CUDA_F16 static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll @@ -294,18 +292,18 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { return x; } -//static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -//#pragma unroll -// for (int mask = 16; mask > 0; mask >>= 1) { -// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); -// } -// return x; -//#else -// GGML_UNUSED(x); -// NO_DEVICE_CODE; -//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -//} +static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +#else + GGML_UNUSED(x); + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +} #if defined(GGML_USE_HIPBLAS) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index dcd2129a2cfc9..1d29346c7453b 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -3,32 +3,6 @@ #include -static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32)); - } - return a; -#else - GGML_UNUSED(a); - NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL -} - -// static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -// #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -// #pragma unroll -// for (int mask = 16; mask > 0; mask >>= 1) { -// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); -// } -// return x; -// #else -// GGML_UNUSED(x); -// NO_DEVICE_CODE; -// #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -// } - #define FATTN_KQ_STRIDE 256 template // D == head size @@ -61,6 +35,7 @@ static __global__ void flash_attn_vec_ext_f16( const int ne1, const int ne2, const int ne3) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y); @@ -201,6 +176,9 @@ static __global__ void flash_attn_vec_ext_f16( dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum); } } +#else + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } template // D == head size, VKQ_stride == num VKQ rows calculated in parallel @@ -233,6 +211,7 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA //In this kernel Q, K, V are matrices while i, j, k are matrix indices. static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); @@ -491,6 +470,9 @@ static __global__ void flash_attn_ext_f16( __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); } } +#else + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA } template // D == head size @@ -499,6 +481,7 @@ static __global__ void flash_attn_combine_results( const float * __restrict__ VKQ_parts, const half2 * __restrict__ VKQ_meta, float * __restrict__ dst) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL const int tid = threadIdx.x; __builtin_assume(tid < D); @@ -527,6 +510,9 @@ static __global__ void flash_attn_combine_results( } dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; +#else + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } constexpr int get_max_power_of_2(int x) { From ee19a4ab7eba53347bde5fa3b3d38e38f7427f55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 2 Apr 2024 17:26:22 +0200 Subject: [PATCH 119/180] fix KV cache padding, NaN from INFINITY (#6438) --- ggml-cuda/fattn.cu | 15 +++++++++------ llama.cpp | 4 ++-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 1d29346c7453b..91ef5551e025a 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -4,6 +4,7 @@ #include #define FATTN_KQ_STRIDE 256 +#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. template // D == head size __launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1) @@ -59,13 +60,13 @@ static __global__ void flash_attn_vec_ext_f16( KQ[tid] = -INFINITY; half2 * KQ2 = (half2 *) KQ; - half kqmax = -INFINITY; + half kqmax = -HALF_MAX_HALF; half kqsum = 0.0f; __shared__ half kqmax_shared[WARP_SIZE]; __shared__ half kqsum_shared[WARP_SIZE]; if (threadIdx.y == 0) { - kqmax_shared[threadIdx.x] = -INFINITY; + kqmax_shared[threadIdx.x] = -HALF_MAX_HALF; kqsum_shared[threadIdx.x] = 0.0f; } __syncthreads(); @@ -139,7 +140,7 @@ static __global__ void flash_attn_vec_ext_f16( if (tid < D) { #pragma unroll for (int k0 = 0; k0 < D; k0 += 2) { - if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { + if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { break; } @@ -253,9 +254,9 @@ static __global__ void flash_attn_ext_f16( __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; half2 * KQ2 = (half2 *) KQ; - half2 KQ_rowsum[ncols/nwarps] = {{0.0f, 0.0f}}; - half2 KQ_max[ncols/nwarps] = {{-INFINITY, -INFINITY}}; - half2 KQ_max_scale[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_rowsum[ncols/nwarps] = {{ 0.0f, 0.0f}}; + half2 KQ_max[ncols/nwarps] = {{-HALF_MAX_HALF, -HALF_MAX_HALF}}; + half2 KQ_max_scale[ncols/nwarps] = {{ 0.0f, 0.0f}}; __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; @@ -578,6 +579,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); + ggml_cuda_set_device(ctx.device); const cudaStream_t main_stream = ctx.stream(); diff --git a/llama.cpp b/llama.cpp index 9ea9886fe3b89..b50588e4467b0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -9973,7 +9973,7 @@ static int llama_decode_internal( // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min(kv_self.size, std::max(128u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128))); + kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256))); //kv_self.n = llama_kv_cache_cell_max(kv_self); } } @@ -13909,7 +13909,7 @@ struct llama_context * llama_new_context_with_model( cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; // this is necessary due to kv_self.n being padded later during inference - cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32); + cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256); // with causal attention, the batch size is limited by the context size cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; From 599ce84a71512b72bf4fd6a248e7725f646eb1a8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 17 Apr 2024 12:00:35 +0300 Subject: [PATCH 120/180] llama : flash_attn cparam + fix defrag --- common/common.cpp | 6 + common/common.h | 1 + llama.cpp | 345 +++++++++++++++++++++++++--------------------- llama.h | 1 + 4 files changed, 194 insertions(+), 159 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index cf69535e2d1f5..fbff8cf13effc 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -900,6 +900,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.cont_batching = true; return true; } + if (arg == "-fa" || arg == "--flash-attn") { + params.flash_attn = true; + return true; + } if (arg == "--color") { params.use_color = true; return true; @@ -1836,6 +1840,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; + cparams.flash_attn = params.flash_attn; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); @@ -2673,6 +2678,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed); fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); + fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); diff --git a/common/common.h b/common/common.h index cca44268e6df5..78a1a34021b60 100644 --- a/common/common.h +++ b/common/common.h @@ -148,6 +148,7 @@ struct gpt_params { bool multiline_input = false; // reverse the usage of `\` bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly + bool flash_attn = false; // flash attention bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens diff --git a/llama.cpp b/llama.cpp index 1081899fbaa3f..0b5ee9ef569c6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -107,8 +107,6 @@ #define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_EXPERTS 60 -#define LLAMA_FLASH_ATTN - // // logging // @@ -1899,6 +1897,7 @@ struct llama_cparams { bool embeddings; bool causal_attn; bool offload_kqv; + bool flash_attn; enum llama_pooling_type pooling_type; @@ -5938,15 +5937,17 @@ static struct ggml_tensor * llm_build_inp_embd( static void llm_build_kv_store( struct ggml_context * ctx, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * k_cur, struct ggml_tensor * v_cur, - int64_t n_ctx, int32_t n_tokens, int32_t kv_head, const llm_build_cb & cb, int64_t il) { + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); @@ -5959,28 +5960,26 @@ static void llm_build_kv_store( // important: storing RoPE-ed version of K in the KV cache! ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); -#if defined(LLAMA_FLASH_ATTN) - // NOTE: the V cache is not transposed when using FLASH attention !! - struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, - (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head); - cb(v_cache_view, "v_cache_view", il); - - ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); + if (cparams.flash_attn) { + // NOTE: the V cache is not transposed when using FLASH attention !! + struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, + (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head); + cb(v_cache_view, "v_cache_view", il); - GGML_UNUSED(n_ctx); -#else - // compute the transposed [n_tokens, n_embd] V matrix - //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); - assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); - struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed - cb(v_cur_t, "v_cur_t", il); + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); + } else { + // compute the transposed [n_tokens, n_embd] V matrix + //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); + assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); + struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); + cb(v_cur_t, "v_cur_t", il); - struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, - ( n_ctx)*ggml_element_size(kv.v_l[il]), - (kv_head)*ggml_element_size(kv.v_l[il])); + struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, + ( n_ctx)*ggml_element_size(kv.v_l[il]), + (kv_head)*ggml_element_size(kv.v_l[il])); - ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); -#endif + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); + } } static struct ggml_tensor * llm_build_norm( @@ -6111,6 +6110,7 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_context * ctx, const llama_model & model, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * wo, @@ -6118,12 +6118,12 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, struct ggml_tensor * kq_pos, - int64_t n_ctx, int32_t n_tokens, int32_t n_kv, float kq_scale, const llm_build_cb & cb, int il) { + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head_k = hparams.n_embd_head_k; @@ -6143,97 +6143,100 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * cur; -#if defined(LLAMA_FLASH_ATTN) - GGML_UNUSED(model); - GGML_UNUSED(n_ctx); + if (cparams.flash_attn) { + GGML_UNUSED(model); + GGML_UNUSED(n_ctx); - GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); + GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); - // split cached v into n_head heads (not transposed) - struct ggml_tensor * v = - ggml_view_3d(ctx, kv.v_l[il], - n_embd_head_v, n_kv, n_head_kv, - ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv.v_l[il]->type, n_embd_head_k), - 0); - cb(v, "v", il); + // split cached v into n_head heads (not transposed) + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.v_l[il]->type, n_embd_head_k), + 0); + cb(v, "v", il); - cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); - ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT); - //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); - //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); - //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); - //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); - //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]); + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); - cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); -#else - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); + if (model.arch == LLM_ARCH_PHI2) { + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); + } + //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); + //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); + //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]); - if (model.arch == LLM_ARCH_PHI2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } + cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); + } else { + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + + if (model.arch == LLM_ARCH_PHI2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } - if (model.arch == LLM_ARCH_GROK) { - // need to do the following: - // multiply by attn_output_multiplyer of 0.08838834764831845 - // and then : - // kq = 30 * tanh(kq / 30) - // before the softmax below + if (model.arch == LLM_ARCH_GROK) { + // need to do the following: + // multiply by attn_output_multiplyer of 0.08838834764831845 + // and then : + // kq = 30 * tanh(kq / 30) + // before the softmax below - //try from phi2 - //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + //try from phi2 + //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); - kq = ggml_scale(ctx, kq, 30); - } + kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); + kq = ggml_scale(ctx, kq, 30); + } #if defined(GGML_USE_KOMPUTE) #pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute") #pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024") #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488") - if (hparams.f_max_alibi_bias > 0.0f) { - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); + if (hparams.f_max_alibi_bias > 0.0f) { + kq = ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else #endif - { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); - } + { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + } - GGML_ASSERT(kv.size == n_ctx); + GGML_ASSERT(kv.size == n_ctx); - // split cached v into n_head heads - struct ggml_tensor * v = - ggml_view_3d(ctx, kv.v_l[il], - n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(kv.v_l[il])*n_ctx, - ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, - 0); - cb(v, "v", il); + // split cached v into n_head heads + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv.v_l[il])*n_ctx, + ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); - struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); + struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); - cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); - cb(cur, "kqv_merged_cont", il); -#endif + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); + cb(cur, "kqv_merged_cont", il); + } ggml_build_forward_expand(graph, cur); @@ -6253,6 +6256,7 @@ static struct ggml_tensor * llm_build_kv( struct ggml_context * ctx, const llama_model & model, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * wo, @@ -6262,7 +6266,6 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, struct ggml_tensor * kq_pos, - int64_t n_ctx, int32_t n_tokens, int32_t kv_head, int32_t n_kv, @@ -6276,12 +6279,12 @@ static struct ggml_tensor * llm_build_kv( ggml_build_forward_expand(graph, k_cur); ggml_build_forward_expand(graph, v_cur); - llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il); + llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il); struct ggml_tensor * cur; - cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b, - q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il); + cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b, + q_cur, kq_mask, kq_pos, n_tokens, n_kv, kq_scale, cb, il); cb(cur, "kqv_out", il); return cur; @@ -6323,6 +6326,8 @@ struct llm_build_context { const int32_t kv_head; // index of where we store new KV data in the cache const int32_t n_orig_ctx; + const bool flash_attn; + const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -6369,6 +6374,7 @@ struct llm_build_context { n_outputs (worst_case ? n_tokens : lctx.n_outputs), kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), + flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -6483,15 +6489,31 @@ struct llm_build_context { ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id)); - ggml_tensor * view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self.v_l[il]->type, kv_self.size), - ggml_row_size(kv_self.v_l[il]->type, i)); + ggml_tensor * view_v_src; + ggml_tensor * view_v_dst; + + if (flash_attn) { + // NOTE: the V cache is not transposed when using flash attention + view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i)); - ggml_tensor * view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self.v_l[il]->type, kv_self.size), - ggml_row_size(kv_self.v_l[il]->type, id)); + view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id)); + } else { + view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(kv_self.v_l[il]->type, kv_self.size), + ggml_row_size(kv_self.v_l[il]->type, i)); + + view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(kv_self.v_l[il]->type, kv_self.size), + ggml_row_size(kv_self.v_l[il]->type, id)); + } ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst)); @@ -6640,9 +6662,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -6846,9 +6868,9 @@ struct llm_build_context { cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -6953,9 +6975,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7073,9 +7095,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7198,9 +7220,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -7343,9 +7365,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7447,9 +7469,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7651,9 +7673,9 @@ struct llm_build_context { ); cb(Vcur, "Vcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Q, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7747,9 +7769,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cb(Qcur, "Qcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8040,9 +8062,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8171,14 +8193,15 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } else { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } } @@ -8320,9 +8343,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8438,9 +8461,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8557,9 +8580,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8671,9 +8694,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8817,9 +8840,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -8919,9 +8942,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } struct ggml_tensor * sa_out = cur; @@ -9022,9 +9045,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9129,9 +9152,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9245,9 +9268,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9362,9 +9385,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9492,9 +9515,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9613,9 +9636,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -9732,9 +9755,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10022,9 +10045,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -11016,7 +11039,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { // each move requires 6*n_layer tensors (see build_defrag) // - source view, destination view, copy operation // - x2 for keys and values - const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer); + //const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer); + // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 + const uint32_t max_moves = (LLAMA_MAX_NODES - 2*n_layer)/(6*n_layer); // determine which KV cells to move where // @@ -14626,6 +14651,7 @@ struct llama_context_params llama_context_default_params() { /*.logits_all =*/ false, /*.embeddings =*/ false, /*.offload_kqv =*/ true, + /*.flash_attn =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -14795,6 +14821,7 @@ struct llama_context * llama_new_context_with_model( cparams.defrag_thold = params.defrag_thold; cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; + cparams.flash_attn = params.flash_attn; cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; diff --git a/llama.h b/llama.h index b5da686f7b7e5..77c288eb25543 100644 --- a/llama.h +++ b/llama.h @@ -270,6 +270,7 @@ extern "C" { bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU + bool flash_attn; // whether to use flash attention // Abort callback // if it returns true, execution of llama_decode() will be aborted From 405385726ef7432b65b9e63dd7a63c18765eb376 Mon Sep 17 00:00:00 2001 From: Pierrick HYMBERT Date: Wed, 17 Apr 2024 14:05:02 +0200 Subject: [PATCH 121/180] server: support flash_attn param --- examples/server/server.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 634e653ada284..f1754b60b7fe8 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2722,6 +2722,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, params.embedding = true; } else if (arg == "-cb" || arg == "--cont-batching") { params.cont_batching = true; + } else if (arg == "-fa" || arg == "--flash-attn") { + params.flash_attn = true; } else if (arg == "-np" || arg == "--parallel") { if (++i >= argc) { invalid_param = true; From 5668c79ea092b7bff95e1fce96e3de717c31349d Mon Sep 17 00:00:00 2001 From: Pierrick HYMBERT Date: Wed, 17 Apr 2024 23:26:29 +0200 Subject: [PATCH 122/180] server: bench: enable flash_attn param --- examples/server/bench/bench.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/server/bench/bench.py b/examples/server/bench/bench.py index 6ca637bddc3a1..86c5de101445c 100644 --- a/examples/server/bench/bench.py +++ b/examples/server/bench/bench.py @@ -268,6 +268,7 @@ def start_server_background(args): server_args.extend(['--defrag-thold', "0.1"]) server_args.append('--cont-batching') server_args.append('--metrics') + server_args.append('--flash-attn') server_args.extend(['--log-format', "text"]) args = [str(arg) for arg in [server_path, *server_args]] print(f"bench: starting server with: {' '.join(args)}") From 34f93bbb39965dd40fe8ad717902d6d109b64afb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 9 Apr 2024 11:39:16 +0200 Subject: [PATCH 123/180] CUDA: refactor host code, dyn. par. blocks --- ggml-cuda.cu | 1 + ggml-cuda/common.cuh | 6 + ggml-cuda/fattn.cu | 542 +++++++++++++++++++------------------------ 3 files changed, 248 insertions(+), 301 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 11adbabd655e1..2cf6c8d98bd89 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -141,6 +141,7 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].cc = 100*prop.major + 10*prop.minor; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) info.devices[id].smpb = prop.sharedMemPerBlock; + info.devices[id].nsm = prop.multiProcessorCount; } for (int id = 0; id < info.device_count; ++id) { diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index b0149b7be22b3..989780dbce88c 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -390,6 +390,11 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { } #endif // defined(GGML_USE_HIPBLAS) +#define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \ + defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL +#define FP16_MMA_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \ + defined(RDNA3) : __CUDA_ARCH__ >= CC_VOLTA + // TODO: move to ggml-common.h static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; @@ -403,6 +408,7 @@ struct ggml_cuda_device_info { struct cuda_device_info { int cc; // compute capability + int nsm; // number of streaming multiprocessors size_t smpb; // max. shared memory per block bool vmm; // virtual memory support size_t vmm_granularity; // granularity of virtual memory diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 91ef5551e025a..5f1345a7fe94f 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -36,18 +36,17 @@ static __global__ void flash_attn_vec_ext_f16( const int ne1, const int ne2, const int ne3) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#if FP16_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic = blockIdx.x / parallel_blocks; // Index of the Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y); + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask; - - if (parallel_blocks == 1) { - Q_f2 += blockIdx.x*nb01/sizeof(float2); - maskh += blockIdx.x*ne11; - } + const half * maskh = (const half *) mask + ne11*ic; const int stride_KV = nb11 / sizeof(half); const int stride_KV2 = nb11 / sizeof(half2); @@ -85,7 +84,7 @@ static __global__ void flash_attn_vec_ext_f16( half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value. - const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*D; + const int k_start = parallel_blocks == 1 ? 0 : ip*D; for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { // Calculate KQ tile and keep track of new maximum KQ values: half kqmax_new = kqmax; @@ -168,18 +167,19 @@ static __global__ void flash_attn_vec_ext_f16( return; } + half dst_val = (__low2half(VKQ) + __high2half(VKQ)); if (parallel_blocks == 1) { - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum; - } else { - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)); + dst_val /= kqsum; + } + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val; - if (tid == 0) { - dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum); - } + if (parallel_blocks == 1 || tid != 0) { + return; } + dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_half2(kqmax, kqsum); #else NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#endif // FP16_AVAILABLE } template // D == head size, VKQ_stride == num VKQ rows calculated in parallel @@ -212,8 +212,12 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA +#if FP16_MMA_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); constexpr int frag_m = ncols == 8 ? 32 : 16; @@ -233,15 +237,10 @@ static __global__ void flash_attn_ext_f16( constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb02* blockIdx.y); + const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half2 * mask2 = (half2 *) mask; - - if (parallel_blocks == 1) { - Q_f += blockIdx.x * ncols*nb01/sizeof(float); - mask2 += blockIdx.x * ncols*ne11/2; - } + const half2 * mask2 = (const half2 *) mask + ne11*(ic0/2); const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); @@ -283,11 +282,7 @@ static __global__ void flash_attn_ext_f16( if (i0 + WARP_SIZE > D && i >= D) { break; } - if (parallel_blocks == 1) { - KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; - } else { - KQ[j*D_padded + i] = j == 0 ? Q_f[j*stride_Q + i] * scale : 0.0f; - } + KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; } } @@ -305,8 +300,7 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); // Iterate over ne11 == previous tokens: - const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*FATTN_KQ_STRIDE; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { + for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { // Calculate tile of KQ: #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { @@ -439,41 +433,39 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); } - if (parallel_blocks == 1) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - if (ncols*blockIdx.x + j >= ne01) { - return; - } - const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { - break; - } - dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; - } + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j_VKQ = j0 + threadIdx.y; + if (ic0 + j_VKQ >= ne01) { + return; } - } else { + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + + const half KQ_rowsum_j = __low2half(KQ_rowsum[j0/nwarps]) + __high2half(KQ_rowsum[j0/nwarps]); #pragma unroll - for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) { - const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x; - if (i0 + nwarps*WARP_SIZE > D && i >= D) { - return; + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + half dst_val = VKQ[j_VKQ*D_padded + i]; + if (parallel_blocks == 1) { + dst_val /= KQ_rowsum_j; } - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i]; + dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; } - if (threadIdx.y == 0 && threadIdx.x == 0) { - dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2( - __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); + if (parallel_blocks == 1 || threadIdx.x != 0) { + continue; } + + half2 dst_meta_val = KQ_max[j0/nwarps]; + reinterpret_cast(dst_meta_val.y) = KQ_rowsum_j; + dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; } #else NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA +#endif // FP16_MMA_AVAILABLE } template // D == head size @@ -482,7 +474,10 @@ static __global__ void flash_attn_combine_results( const float * __restrict__ VKQ_parts, const half2 * __restrict__ VKQ_meta, float * __restrict__ dst) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#if FP16_AVAILABLE + VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; + dst += D * gridDim.y*blockIdx.x; const int tid = threadIdx.x; __builtin_assume(tid < D); @@ -513,7 +508,7 @@ static __global__ void flash_attn_combine_results( dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; #else NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#endif // FP16_AVAILABLE } constexpr int get_max_power_of_2(int x) { @@ -540,26 +535,124 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); -#define FATTN_SWITCH_CASE(D, ncols, nwarps) \ - case ncols: { \ - constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \ - flash_attn_ext_f16 \ - <<>> ( \ - (const char *) Q->data, \ - (const char *) K->data, \ - (const char *) V->data, \ - mask ? ((const char *) mask->data) : nullptr, \ - (float *) KQV->data, nullptr, \ - scale, \ - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \ - K->ne[0], K->ne[1], K->ne[2], K->ne[3], \ - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \ - Q->nb[1], Q->nb[2], Q->nb[3], \ - K->nb[1], K->nb[2], K->nb[3], \ - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \ - ); \ - } \ - break; \ +template void launch_fattn_vec_f16( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + ggml_cuda_pool & pool, cudaStream_t main_stream +) { + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + constexpr int nwarps = ((D) + WARP_SIZE - 1) / WARP_SIZE; + constexpr dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); + const int shmem = 0; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + + flash_attn_vec_ext_f16 + <<>> ( + (const char *) Q->data, + (const char *) K->data, + (const char *) V->data, + mask ? ((const char *) mask->data) : nullptr, + (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if ((parallel_blocks) == 1) { + return; + } + + constexpr dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + CUDA_CHECK(cudaGetLastError()); +} + +template void launch_fattn_f16_impl( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + ggml_cuda_pool & pool, cudaStream_t main_stream +) { + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16; + constexpr dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); + const int shmem = 0; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + + flash_attn_ext_f16 + <<>> ( + (const char *) Q->data, + (const char *) K->data, + (const char *) V->data, + mask ? ((const char *) mask->data) : nullptr, + (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if ((parallel_blocks) == 1) { + return; + } + + constexpr dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + CUDA_CHECK(cudaGetLastError()); +} + +template void launch_fattn_f16( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream +) { + const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; + + if (4*blocks_num_pb1 < 2*nsm) { + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + return; + } + if (2*blocks_num_pb1 < 2*nsm) { + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + return; + } + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); +} void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; @@ -583,259 +676,106 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); - const cudaStream_t main_stream = ctx.stream(); - - float scale; - memcpy(&scale, KQV->op_params, sizeof(float)); - - if (Q->ne[1] == 1) { + if (Q->ne[1] == 1 && Q->ne[0] % WARP_SIZE == 0) { constexpr int parallel_blocks = 4; - - ggml_cuda_pool_alloc dst_tmp(ctx.pool()); - ggml_cuda_pool_alloc dst_tmp_meta(ctx.pool()); - - const int nwarps = (Q->ne[0] + WARP_SIZE - 1) / WARP_SIZE; - const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); - const dim3 block_dim(WARP_SIZE, nwarps, 1); - const int shmem = 0; - - // Performance of the vector kernel is very bad for head sizes 80 and 112, use the tensor core kernel instead: - constexpr int nwarps_tc = 4; - constexpr dim3 block_dim_tc(WARP_SIZE, nwarps_tc, 1); - - const dim3 blocks_num_combine(1, blocks_num.y, blocks_num.z); - const dim3 block_dim_combine(Q->ne[0], 1, 1); - const int shmem_combine = 0; - - if (parallel_blocks > 1) { - dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); - dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); - } - switch (Q->ne[0]) { case 64: - flash_attn_vec_ext_f16<64, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<64, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); - break; - case 80: - flash_attn_ext_f16<80, 16, nwarps_tc, get_VKQ_stride(80, nwarps_tc, 16), parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<80, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; case 96: - flash_attn_vec_ext_f16<96, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<96, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); - break; - case 112: - flash_attn_vec_ext_f16<112, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<112, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16< 96, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; case 128: - flash_attn_vec_ext_f16<128, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<128, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; case 256: - flash_attn_vec_ext_f16<256, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<256, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; default: GGML_ASSERT(false); break; } - CUDA_CHECK(cudaGetLastError()); return; } - int cols_per_block; - if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { - cols_per_block = 32; - } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { - cols_per_block = 16; - } else { - cols_per_block = 8; - } - constexpr int nwarps = 4; - const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); - const dim3 block_dim(WARP_SIZE, nwarps, 1); - const size_t shmem = 0; + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - switch (Q->ne[0]) { - case 64: switch (cols_per_block) { - FATTN_SWITCH_CASE(64, 8, nwarps); - FATTN_SWITCH_CASE(64, 16, nwarps); - FATTN_SWITCH_CASE(64, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); + if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { + constexpr int cols_per_block = 8; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; - } break; - case 80: switch (cols_per_block) { - // FATTN_SWITCH_CASE(80, 8, nwarps); - FATTN_SWITCH_CASE(80, 16, nwarps); - FATTN_SWITCH_CASE(80, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; - } break; - case 96: switch (cols_per_block) { - FATTN_SWITCH_CASE(96, 8, nwarps); - FATTN_SWITCH_CASE(96, 16, nwarps); - FATTN_SWITCH_CASE(96, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); + case 128: + launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; - } break; - case 112: switch (cols_per_block) { - // FATTN_SWITCH_CASE(112, 8, nwarps); - FATTN_SWITCH_CASE(112, 16, nwarps); - FATTN_SWITCH_CASE(112, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); + case 256: + launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; - } break; - case 128: switch (cols_per_block) { - FATTN_SWITCH_CASE(128, 8, nwarps); - FATTN_SWITCH_CASE(128, 16, nwarps); - FATTN_SWITCH_CASE(128, 32, nwarps); default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; - } break; - case 256: switch (cols_per_block) { - FATTN_SWITCH_CASE(256, 8, nwarps); - FATTN_SWITCH_CASE(256, 16, nwarps); - FATTN_SWITCH_CASE(256, 32, nwarps); + } + return; + } + + if (Q->ne[1] <= 32) { + constexpr int cols_per_block = 16; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; - } break; + } + return; + } + + constexpr int cols_per_block = 32; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; default: GGML_ASSERT(false); break; } - CUDA_CHECK(cudaGetLastError()); + return; } From 6a3b84236de279f0fe012cfca0c168472526b696 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 13 Apr 2024 22:05:43 +0200 Subject: [PATCH 124/180] fix flash_attn_vec_f16 race condition --- ggml-cuda/fattn.cu | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 5f1345a7fe94f..36479b2170979 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -149,6 +149,8 @@ static __global__ void flash_attn_vec_ext_f16( VKQ += V_k*KQ2[k0/2]; } } + + __syncthreads(); } if (tid >= D) { @@ -547,7 +549,7 @@ template void launch_fattn_vec_f16( dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); } - constexpr int nwarps = ((D) + WARP_SIZE - 1) / WARP_SIZE; + constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; constexpr dim3 block_dim(WARP_SIZE, nwarps, 1); const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); const int shmem = 0; @@ -561,7 +563,7 @@ template void launch_fattn_vec_f16( (const char *) K->data, (const char *) V->data, mask ? ((const char *) mask->data) : nullptr, - (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, scale, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], @@ -572,7 +574,7 @@ template void launch_fattn_vec_f16( ); CUDA_CHECK(cudaGetLastError()); - if ((parallel_blocks) == 1) { + if (parallel_blocks == 1) { return; } From ef9e1593f33df5dbc8f89b927a8a7bd9dfc9e6dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 15 Apr 2024 16:05:07 +0200 Subject: [PATCH 125/180] flush softmax exp below threshold to 0 --- ggml-cuda/fattn.cu | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 36479b2170979..f6289822e0ea0 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -3,8 +3,9 @@ #include -#define FATTN_KQ_STRIDE 256 -#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. +#define FATTN_KQ_STRIDE 256 +#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. +#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. template // D == head size __launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1) @@ -338,10 +339,16 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(kqs_padded/2) + k]); + half2 val = KQ2[j*(kqs_padded/2) + k]; + val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = __hmax2(KQ_max_new, val); + KQ2[j*(kqs_padded/2) + k] = val; } KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); - KQ_max_scale[j0/nwarps] = h2exp(KQ_max[j0/nwarps] - KQ_max_new); + const half2 diff = KQ_max[j0/nwarps] - KQ_max_new; + KQ_max_scale[j0/nwarps] = h2exp(diff); + const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &KQ_max_scale[j0/nwarps]) &= ftz_mask; KQ_max[j0/nwarps] = KQ_max_new; half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); @@ -350,8 +357,10 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + threadIdx.x; half2 val = KQ2[j*(kqs_padded/2) + k]; - val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - val = h2exp(val - KQ_max[j0/nwarps]); + const half2 diff = val - KQ_max[j0/nwarps]; + val = h2exp(diff); + const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &val) &= ftz_mask; KQ_rowsum_add += val; KQ2[j*(kqs_padded/2) + k] = val; } @@ -501,7 +510,10 @@ static __global__ void flash_attn_combine_results( float VKQ_denominator = 0.0f; #pragma unroll for (int l = 0; l < parallel_blocks; ++l) { - float KQ_max_scale = hexp(__low2half(meta[l]) - kqmax); + const half diff = __low2half(meta[l]) - kqmax; + float KQ_max_scale = hexp(diff); + const uint ftz_mask = 0xFFFFFFFF * (diff > __float2half(SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &KQ_max_scale) &= ftz_mask; VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; VKQ_denominator += KQ_max_scale * __high2float(meta[l]); From a5b0e2dea018cfac5ee478aac0d780eef391b30b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 16 Apr 2024 15:58:21 +0200 Subject: [PATCH 126/180] store temp KQ in registers --- ggml-cuda/fattn.cu | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index f6289822e0ea0..b889cdb3b9b01 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -335,14 +335,21 @@ static __global__ void flash_attn_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + } + half2 KQ_max_new = KQ_max[j0/nwarps]; #pragma unroll for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - half2 val = KQ2[j*(kqs_padded/2) + k]; - val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - KQ_max_new = __hmax2(KQ_max_new, val); - KQ2[j*(kqs_padded/2) + k] = val; + + KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); } KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); const half2 diff = KQ_max[j0/nwarps] - KQ_max_new; @@ -356,13 +363,12 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - half2 val = KQ2[j*(kqs_padded/2) + k]; - const half2 diff = val - KQ_max[j0/nwarps]; - val = h2exp(diff); + const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max[j0/nwarps]; + KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint *) &val) &= ftz_mask; - KQ_rowsum_add += val; - KQ2[j*(kqs_padded/2) + k] = val; + *((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; } KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); From 0bc67dd1c81c15f04096985f9a85d81b431767b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 16 Apr 2024 16:22:29 +0200 Subject: [PATCH 127/180] Calculate KQ as FP32 if KQV has GGML_PREC_F32 --- ggml-cuda/fattn.cu | 286 +++++++++++++++++++++++++++++++++------------ 1 file changed, 213 insertions(+), 73 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index b889cdb3b9b01..dda344531335c 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -1,6 +1,7 @@ #include "common.cuh" #include "fattn.cuh" +#include #include #define FATTN_KQ_STRIDE 256 @@ -185,7 +186,8 @@ static __global__ void flash_attn_vec_ext_f16( #endif // FP16_AVAILABLE } -template // D == head size, VKQ_stride == num VKQ rows calculated in parallel +// D == head size, VKQ_stride == num VKQ rows calculated in parallel: +template __launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -229,7 +231,8 @@ static __global__ void flash_attn_ext_f16( typedef nvcuda::wmma::fragment frag_a_K; typedef nvcuda::wmma::fragment frag_a_V; typedef nvcuda::wmma::fragment frag_b; - typedef nvcuda::wmma::fragment frag_c; + typedef nvcuda::wmma::fragment frag_c_KQ; + typedef nvcuda::wmma::fragment frag_c_VKQ; constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. @@ -238,12 +241,14 @@ static __global__ void flash_attn_ext_f16( // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: constexpr int D_padded = D + 8; constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; + constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half2 * mask2 = (const half2 *) mask + ne11*(ic0/2); + const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; + const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); @@ -251,14 +256,29 @@ static __global__ void flash_attn_ext_f16( frag_b Q_b[D/16][ncols/frag_n]; // A single buffer for temporarily holding tiles of KQ and VKQ parts: - constexpr int mem_KQ = ncols*kqs_padded; + constexpr int mem_KQ = ncols*kqs_padded*kqar; constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; + float * KQ_f = (float *) KQ; half2 * KQ2 = (half2 *) KQ; - half2 KQ_rowsum[ncols/nwarps] = {{ 0.0f, 0.0f}}; - half2 KQ_max[ncols/nwarps] = {{-HALF_MAX_HALF, -HALF_MAX_HALF}}; - half2 KQ_max_scale[ncols/nwarps] = {{ 0.0f, 0.0f}}; + float KQ_rowsum_f[ncols/nwarps] = {0.0f}; + float KQ_max_f[ncols/nwarps]; + float KQ_max_scale_f[ncols/nwarps] = {0.0f}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_f[j] = -FLT_MAX/2.0f; + } + + half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max_h2[ncols/nwarps]; + half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); + } __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; @@ -307,7 +327,7 @@ static __global__ void flash_attn_ext_f16( // Calculate tile of KQ: #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { - frag_c KQ_c[ncols/frag_n]; + frag_c_KQ KQ_c[ncols/frag_n]; #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); @@ -323,7 +343,7 @@ static __global__ void flash_attn_ext_f16( } #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync(KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); + nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); } } @@ -335,45 +355,90 @@ static __global__ void flash_attn_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; - half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; + if (std::is_same::value) { + float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; - KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; - } + KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; + } - half2 KQ_max_new = KQ_max[j0/nwarps]; + float KQ_max_new = KQ_max_f[j0/nwarps]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; - KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); - } - KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); - const half2 diff = KQ_max[j0/nwarps] - KQ_max_new; - KQ_max_scale[j0/nwarps] = h2exp(diff); - const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint *) &KQ_max_scale[j0/nwarps]) &= ftz_mask; - KQ_max[j0/nwarps] = KQ_max_new; + KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); + } + KQ_max_new = warp_reduce_max(KQ_max_new); + + const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; + KQ_max_scale_f[j0/nwarps] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_max_scale_f[j0/nwarps] = 0.0f; + } + KQ_max_f[j0/nwarps] = KQ_max_new; + + float KQ_rowsum_add = 0.0f; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; - half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); + const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; + KQ_f_tmp[k0/WARP_SIZE] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_f_tmp[k0/WARP_SIZE] = 0.0f; + } + KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; + KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; + } else { + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + } + + half2 KQ_max_new = KQ_max_h2[j0/nwarps]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; - const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max[j0/nwarps]; - KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); + KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); + } + KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; + KQ_max_scale_h2[j0/nwarps] = h2exp(diff); const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; - KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; - KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; - } - KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + *((uint *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; + KQ_max_h2[j0/nwarps] = KQ_max_new; - // Scale previous KQ_rowsum to account for a potential increase in KQ_max: - KQ_rowsum[j0/nwarps] = KQ_max_scale[j0/nwarps]*KQ_rowsum[j0/nwarps] + KQ_rowsum_add; + half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; + KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); + const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; + } } __syncthreads(); @@ -386,12 +451,12 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; nvcuda::wmma::load_matrix_sync( KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], - KQ + j0*kqs_padded + k, - kqs_padded); + KQ + j0*(kqar*kqs_padded) + k, + kqar*kqs_padded); } } - frag_c VKQ_c[D/VKQ_stride][ncols/frag_n]; + frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { #pragma unroll @@ -431,6 +496,14 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; + + half2 VKQ_scale; + if (std::is_same::value) { + VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]); + } else { + VKQ_scale = KQ_max_scale_h2[j0/nwarps]; + } + #pragma unroll for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; @@ -443,7 +516,7 @@ static __global__ void flash_attn_ext_f16( for (int l = 0; l < VKQ_ratio; ++l) { VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; } - VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + VKQ_add; + VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; } } @@ -458,14 +531,20 @@ static __global__ void flash_attn_ext_f16( } const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - const half KQ_rowsum_j = __low2half(KQ_rowsum[j0/nwarps]) + __high2half(KQ_rowsum[j0/nwarps]); + float KQ_rowsum_j; + if (std::is_same::value) { + KQ_rowsum_j = KQ_rowsum_f[j0/nwarps]; + } else { + KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); + } + #pragma unroll for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; if (i0 + WARP_SIZE > D && i >= D) { break; } - half dst_val = VKQ[j_VKQ*D_padded + i]; + float dst_val = VKQ[j_VKQ*D_padded + i]; if (parallel_blocks == 1) { dst_val /= KQ_rowsum_j; } @@ -476,7 +555,12 @@ static __global__ void flash_attn_ext_f16( continue; } - half2 dst_meta_val = KQ_max[j0/nwarps]; + half2 dst_meta_val; + if (std::is_same::value) { + reinterpret_cast(dst_meta_val.x) = KQ_max_f[j0/nwarps]; + } else { + dst_meta_val = KQ_max_h2[j0/nwarps]; + } reinterpret_cast(dst_meta_val.y) = KQ_rowsum_j; dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; } @@ -606,7 +690,7 @@ template void launch_fattn_vec_f16( CUDA_CHECK(cudaGetLastError()); } -template void launch_fattn_f16_impl( +template void launch_fattn_f16_impl( const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, ggml_cuda_pool & pool, cudaStream_t main_stream ) { @@ -626,7 +710,7 @@ template void launc float scale; memcpy(&scale, KQV->op_params, sizeof(float)); - flash_attn_ext_f16 + flash_attn_ext_f16 <<>> ( (const char *) Q->data, (const char *) K->data, @@ -657,21 +741,21 @@ template void launc CUDA_CHECK(cudaGetLastError()); } -template void launch_fattn_f16( +template void launch_fattn_f16( const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream ) { const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; if (4*blocks_num_pb1 < 2*nsm) { - launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); return; } if (2*blocks_num_pb1 < 2*nsm) { - launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); return; } - launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); } void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -696,15 +780,73 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); - if (Q->ne[1] == 1 && Q->ne[0] % WARP_SIZE == 0) { + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + + const int32_t precision = KQV->op_params[1]; + + if (precision != GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 32 || Q->ne[0] > 128) { + constexpr int cols_per_block = 16; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + } else { + constexpr int cols_per_block = 32; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + // case 256: + // launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + // break; + default: + GGML_ASSERT(false); + break; + } + } + return; + } + + if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) { constexpr int parallel_blocks = 4; switch (Q->ne[0]) { case 64: launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; - case 96: - launch_fattn_vec_f16< 96, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); - break; case 128: launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; @@ -718,23 +860,21 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { constexpr int cols_per_block = 8; constexpr int nwarps = 4; switch (Q->ne[0]) { case 64: - launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 96: - launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 128: - launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 256: - launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; default: GGML_ASSERT(false); @@ -748,22 +888,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst constexpr int nwarps = 4; switch (Q->ne[0]) { case 64: - launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 80: - launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 96: - launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 112: - launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 128: - launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 256: - launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; default: GGML_ASSERT(false); @@ -776,22 +916,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst constexpr int nwarps = 4; switch (Q->ne[0]) { case 64: - launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 80: - launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 96: - launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 112: - launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 128: - launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 256: - launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; default: GGML_ASSERT(false); From 2f538b9547ec2c2c67be0d41ed96d33c141354fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 17 Apr 2024 16:29:28 +0200 Subject: [PATCH 128/180] Add __hgt2_mask implementation for CUDA 11 --- ggml-cuda/common.cuh | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 989780dbce88c..ac6de643d668e 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -306,6 +306,13 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX } +#if CUDART_VERSION < 12000 +static __device__ __forceinline__ uint __hgt2_mask(const half2 a, const half2 b) { + const uint mask_low = 0x0000FFFF * ( __low2half(a) > __low2half(b)); + const uint mask_high = 0xFFFF0000 * (__high2half(a) > __high2half(b)); + return mask_low | mask_high; +} +#endif // CUDART_VERSION < 12000 #if defined(GGML_USE_HIPBLAS) #define __CUDA_ARCH__ 1300 From 87968de9a99d9820d20aeb5a15211f4ab5efde83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 17 Apr 2024 17:31:03 +0200 Subject: [PATCH 129/180] fix KQ FP32 precision fpr parallel_blocks > 1 --- ggml-cuda/fattn.cu | 48 +++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index dda344531335c..4cf2907e8d10c 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -15,8 +15,8 @@ static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, - float * __restrict__ dst, - half2 * __restrict__ dst_meta, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, const float scale, const int ne00, const int ne01, @@ -180,7 +180,7 @@ static __global__ void flash_attn_vec_ext_f16( if (parallel_blocks == 1 || tid != 0) { return; } - dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_half2(kqmax, kqsum); + dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum); #else NO_DEVICE_CODE; #endif // FP16_AVAILABLE @@ -194,8 +194,8 @@ static __global__ void flash_attn_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, - float * __restrict__ dst, - half2 * __restrict__ dst_meta, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, const float scale, const int ne00, const int ne01, @@ -555,13 +555,13 @@ static __global__ void flash_attn_ext_f16( continue; } - half2 dst_meta_val; + float2 dst_meta_val; if (std::is_same::value) { - reinterpret_cast(dst_meta_val.x) = KQ_max_f[j0/nwarps]; + dst_meta_val.x = KQ_max_f[j0/nwarps]; } else { - dst_meta_val = KQ_max_h2[j0/nwarps]; + dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); } - reinterpret_cast(dst_meta_val.y) = KQ_rowsum_j; + dst_meta_val.y = KQ_rowsum_j; dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; } #else @@ -572,8 +572,8 @@ static __global__ void flash_attn_ext_f16( template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_combine_results( - const float * __restrict__ VKQ_parts, - const half2 * __restrict__ VKQ_meta, + const float * __restrict__ VKQ_parts, + const float2 * __restrict__ VKQ_meta, float * __restrict__ dst) { #if FP16_AVAILABLE VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; @@ -583,30 +583,30 @@ static __global__ void flash_attn_combine_results( const int tid = threadIdx.x; __builtin_assume(tid < D); - __shared__ half2 meta[parallel_blocks]; - if (tid < parallel_blocks) { - meta[threadIdx.x] = VKQ_meta[blockIdx.y*parallel_blocks + tid]; + __shared__ float2 meta[parallel_blocks]; + if (tid < 2*parallel_blocks) { + ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid]; } __syncthreads(); - half kqmax = __low2half(meta[0]); + float kqmax = meta[0].x; #pragma unroll for (int l = 1; l < parallel_blocks; ++l) { - kqmax = __hmax(kqmax, __low2half(meta[l])); + kqmax = max(kqmax, meta[l].x); } float VKQ_numerator = 0.0f; float VKQ_denominator = 0.0f; #pragma unroll for (int l = 0; l < parallel_blocks; ++l) { - const half diff = __low2half(meta[l]) - kqmax; - float KQ_max_scale = hexp(diff); - const uint ftz_mask = 0xFFFFFFFF * (diff > __float2half(SOFTMAX_FTZ_THRESHOLD)); + const float diff = meta[l].x - kqmax; + const float KQ_max_scale = expf(diff); + const uint ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); *((uint *) &KQ_max_scale) &= ftz_mask; VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; - VKQ_denominator += KQ_max_scale * __high2float(meta[l]); + VKQ_denominator += KQ_max_scale * meta[l].y; } dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; @@ -643,8 +643,8 @@ template void launch_fattn_vec_f16( const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, ggml_cuda_pool & pool, cudaStream_t main_stream ) { - ggml_cuda_pool_alloc dst_tmp(pool); - ggml_cuda_pool_alloc dst_tmp_meta(pool); + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); @@ -694,8 +694,8 @@ template dst_tmp(pool); - ggml_cuda_pool_alloc dst_tmp_meta(pool); + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); From 260cdb2d082d1658d1c6b693c4cbf77754873886 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Apr 2024 14:28:19 +0300 Subject: [PATCH 130/180] llama-bench : add -fa,--flash-attn arg --- examples/llama-bench/llama-bench.cpp | 30 +++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 8b532c8b6a98a..95c3095dd04da 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -174,6 +174,7 @@ struct cmd_params { std::vector split_mode; std::vector main_gpu; std::vector no_kv_offload; + std::vector flash_attn; std::vector> tensor_split; std::vector use_mmap; std::vector embeddings; @@ -195,6 +196,7 @@ static const cmd_params cmd_params_defaults = { /* split_mode */ {LLAMA_SPLIT_MODE_LAYER}, /* main_gpu */ {0}, /* no_kv_offload */ {false}, + /* flash_attn */ {false}, /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -220,6 +222,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -sm, --split-mode (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str()); printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); + printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); printf(" -ts, --tensor-split (default: 0)\n"); @@ -393,6 +396,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = split(argv[i], split_delim); params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end()); + } else if (arg == "-fa" || arg == "--flash-attn") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = split(argv[i], split_delim); + params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end()); } else if (arg == "-mmp" || arg == "--mmap") { if (++i >= argc) { invalid_param = true; @@ -477,6 +487,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.split_mode.empty()) { params.split_mode = cmd_params_defaults.split_mode; } if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; } if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; } + if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } @@ -498,6 +509,7 @@ struct cmd_params_instance { llama_split_mode split_mode; int main_gpu; bool no_kv_offload; + bool flash_attn; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -532,6 +544,7 @@ struct cmd_params_instance { cparams.type_k = type_k; cparams.type_v = type_v; cparams.offload_kqv = !no_kv_offload; + cparams.flash_attn = flash_attn; cparams.embeddings = embeddings; return cparams; @@ -554,6 +567,7 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & tk : params.type_k) for (const auto & tv : params.type_v) for (const auto & nkvo : params.no_kv_offload) + for (const auto & fa : params.flash_attn) for (const auto & nt : params.n_threads) { for (const auto & n_prompt : params.n_prompt) { if (n_prompt == 0) { @@ -572,6 +586,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .split_mode = */ sm, /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, + /* .flash_attn = */ fa, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -596,6 +611,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .split_mode = */ sm, /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, + /* .flash_attn = */ fa, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -633,6 +649,7 @@ struct test { llama_split_mode split_mode; int main_gpu; bool no_kv_offload; + bool flash_attn; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -657,6 +674,7 @@ struct test { split_mode = inst.split_mode; main_gpu = inst.main_gpu; no_kv_offload = inst.no_kv_offload; + flash_attn = inst.flash_attn; tensor_split = inst.tensor_split; use_mmap = inst.use_mmap; embeddings = inst.embeddings; @@ -731,7 +749,7 @@ struct test { "n_batch", "n_ubatch", "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", - "main_gpu", "no_kv_offload", + "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "use_mmap", "embeddings", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", @@ -753,7 +771,7 @@ struct test { } if (field == "cuda" || field == "opencl" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "use_mmap" || field == "embeddings") { + field == "flash_attn" || field == "use_mmap" || field == "embeddings") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -787,7 +805,7 @@ struct test { std::to_string(n_batch), std::to_string(n_ubatch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), - std::to_string(main_gpu), std::to_string(no_kv_offload), + std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), @@ -955,6 +973,9 @@ struct markdown_printer : public printer { if (field == "no_kv_offload") { return "nkvo"; } + if (field == "flash_attn") { + return "fa"; + } if (field == "use_mmap") { return "mmap"; } @@ -1001,6 +1022,9 @@ struct markdown_printer : public printer { if (params.no_kv_offload.size() > 1 || params.no_kv_offload != cmd_params_defaults.no_kv_offload) { fields.emplace_back("no_kv_offload"); } + if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) { + fields.emplace_back("flash_attn"); + } if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) { fields.emplace_back("tensor_split"); } From 105332cc17b8bd8f3989606b59489ed85eefe04f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Apr 2024 14:33:07 +0300 Subject: [PATCH 131/180] metal : add BS=1 kernel for flash attention (#6508) * metal : add BS=1 kernel for flash attention (wip) * metal : support more than 1 warps * metal : opts * metal : opt * metal : switch to parallel reduce * metal : reduce registers * metal : simplify * metal : initial FA vec kernel --- ggml-metal.m | 119 ++++++++++++++------ ggml-metal.metal | 274 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 361 insertions(+), 32 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index e8613dbee04d2..407f94eb224cf 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -183,6 +183,8 @@ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -621,12 +623,14 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -2563,19 +2567,32 @@ static enum ggml_status ggml_metal_graph_compute( id pipeline = nil; - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; - default: - { - GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - GGML_ASSERT(false && "add template specialization for this size"); - } + if (ne01 > 1 || (ne00%128 != 0)) { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } + } else { + switch (ne00) { + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } } // TODO: extend if necessary @@ -2609,24 +2626,62 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + // half8x8 kernel + if (ne01 > 1 || (ne00%128 != 0)) { + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - GGML_ASSERT(nqptg <= 32); - GGML_ASSERT(nqptg % 8 == 0); - GGML_ASSERT(ncpsg % 32 == 0); + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); - // simdgroups per threadgroup (a.k.a. warps) - // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; + // simdgroups per threadgroup (a.k.a. warps) + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; - const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); + const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); - //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); - GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:smem atIndex:0]; + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + // half1x4 kernel + const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 1 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + // simdgroups per threadgroup (a.k.a. warps) + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + + int64_t nsg = 1; + while (nsg <= nsgt) { + nsg *= 2; + } + nsg /= 2; + + // require power of 2 + //{ + // int64_t nsgm = 1; + // while (nsgm < nsg) { + // nsgm *= 2; + // } + // GGML_ASSERT(nsg == nsgm); + //} + + const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index be47db86ebbd3..de6072e93470a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2494,6 +2494,280 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; +#define HALF_MAX_HALF half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. + +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_vec_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]; + + const short D4 = D/4; + const short D8 = D/8; + const short NW = N_SIMDWIDTH; + const short SH = (C + 1); // shared memory per simdgroup in (half) + + const short T = D + nsg*SH; // shared memory size per query in (half) + const short T4 = T/4; // shared memory size per query in (half4) + + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4 + threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + half4 lo[D4/NW]; + + // load heads from Q to shared memory + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 < ne01) { + sq4[i] = (half4) q4[i]; + } else { + sq4[i] = 0.0h; + } + } + + // zero out lo + for (short i = tiisg; i < D4; i += NW) { + lo[i/NW] = 0.0h; + } + + // zero out shared memory SH + for (short i = tiisg; i < SH/4; i += NW) { + ss4[i] = 0.0h; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + half S = { 0.0h }; + half M = { -HALF_MAX_HALF }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + const uint nb21 = nb11; + const uint nb22 = nb12; + const uint nb23 = nb13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2 / rk2; + const short ik3 = iq3 / rk3; + + // v indices + const short iv2 = iq2 / rv2; + const short iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + half4 mq[D4]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + mq[i] = sq4[i]; + } + + // pointer to the mask + device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + half4 mqk = { 0.0h }; + + device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + half4x4 mk; + mk[0] = pk4[i + 0*(nb11/8)]; + mk[1] = pk4[i + 1*(nb11/8)]; + mk[2] = pk4[i + 2*(nb11/8)]; + mk[3] = pk4[i + 3*(nb11/8)]; + + mqk += mq[i] * mk; + } + + // reduce the results from the threads in the simdgroup + mqk += simd_shuffle_down(mqk, 16); + mqk += simd_shuffle_down(mqk, 8); + mqk += simd_shuffle_down(mqk, 4); + mqk += simd_shuffle_down(mqk, 2); + mqk += simd_shuffle_down(mqk, 1); + + // mqk = mqk*scale + mask + if (tiisg == 0) { + half4 mm = mp4[ic/4 + cc]; + mqk = mqk*scale + mm; + + ss4[cc] = mqk; + } + } + } + + // online softmax + { + const short p = tiisg; + + const half m = M; + const half s = ss[p]; + + M = simd_max(max(M, s)); + + const half ms = exp(m - M); + const half vs = exp(s - M); + + S = S*ms + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[p] = vs; + + // O = diag(ms)*O +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + lo[i/NW] *= ms; + } + } + + // O = O + (Q*K^T)*V + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; + lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; + lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; + lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; + } + } + } + + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + } + + // store results to shared memory + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = lo[ii/NW]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // parallel reduce + for (short r = nsg/2; r > 0; r >>= 1) { + if (sgitg < r) { + const half S0 = ss[ 0]; + const half S1 = ss[r*SH + 0]; + + const half M0 = ss[ 1]; + const half M1 = ss[r*SH + 1]; + + const half M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + const half S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + const half S = ss[0]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S; + } + } +} + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 32>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 32>; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, From c16a7c26882669a0d2ed7ef592cde3b0227248f2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Apr 2024 20:08:52 +0300 Subject: [PATCH 132/180] metal : use F32 attention accumulators --- ggml-metal.m | 15 +---- ggml-metal.metal | 156 +++++++++++++++++++++++------------------------ ggml.c | 3 +- 3 files changed, 81 insertions(+), 93 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 407f94eb224cf..f4a831b52a9f9 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2636,10 +2636,9 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(ncpsg % 32 == 0); // simdgroups per threadgroup (a.k.a. warps) - // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; - const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); + const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); @@ -2656,7 +2655,6 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(ncpsg % 32 == 0); // simdgroups per threadgroup (a.k.a. warps) - // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); int64_t nsg = 1; @@ -2665,16 +2663,7 @@ static enum ggml_status ggml_metal_graph_compute( } nsg /= 2; - // require power of 2 - //{ - // int64_t nsgm = 1; - // while (nsgm < nsg) { - // nsgm *= 2; - // } - // GGML_ASSERT(nsg == nsgm); - //} - - const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); + const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); diff --git a/ggml-metal.metal b/ggml-metal.metal index de6072e93470a..36b87b2f0750a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2169,12 +2169,13 @@ kernel void kernel_flash_attn_ext_f16( const short NW = N_SIMDWIDTH; const short SH = (C + Q); // shared memory per simdgroup in (half) - const short T = D + nsg*SH; // shared memory size per query in (half) + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + const short TF = T/2; // shared memory size per query in (float) const short T4 = T/4; // shared memory size per query in (half4) - threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 - threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) simdgroup_half8x8 lo[Q8][D8]; @@ -2202,15 +2203,15 @@ kernel void kernel_flash_attn_ext_f16( // zero out shared memory SH for (short j = 0; j < Q; ++j) { for (short i = tiisg; i < SH; i += NW) { - ss[j*T + i] = 0.0h; + ss[j*TF + i] = 0.0f; } } threadgroup_barrier(mem_flags::mem_threadgroup); { - half S[Q] = { [0 ... Q-1] = 0.0h }; - half M[Q] = { [0 ... Q-1] = -INFINITY }; + float S[Q] = { [0 ... Q-1] = 0.0h }; + float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; // assume K and V are same shape const short ne22 = ne12; @@ -2248,7 +2249,7 @@ kernel void kernel_flash_attn_ext_f16( device const half * mp = (device const half *) (mask + iq1*nb31); // prepare diagonal scale matrix - simdgroup_half8x8 mscale(scale); + simdgroup_float8x8 mscale(scale); // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -2261,9 +2262,9 @@ kernel void kernel_flash_attn_ext_f16( // Q*K^T { for (short cc = 0; cc < C/8; ++cc) { - simdgroup_half8x8 mqk[Q8]; + simdgroup_float8x8 mqk[Q8]; for (short j = 0; j < Q8; ++j) { - mqk[j] = make_filled_simdgroup_matrix(0.h); + mqk[j] = make_filled_simdgroup_matrix(0.h); } device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); @@ -2283,48 +2284,48 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); - simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); + simdgroup_store(mqk[j], ss + 8*j*TF + 8*cc, TF, 0, false); } } } // used to detect blocks full of -INF - half smax = -INFINITY; + float smax = -INFINITY; // online softmax if (C == 32) { - half ms[Q]; + float ms[Q]; for (short j = 0; j < Q; ++j) { const short p = tiisg; - const half m = M[j]; - const half s = ss[j*T + p]; + const float m = M[j]; + const float s = ss[j*TF + p]; smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); - ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + ms[j] = exp(m - M[j]); + const float vs = exp(s - M[j]); S[j] = S[j]*ms[j] + simd_sum(vs); // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + ss[j*TF + p] = vs; } // create a QxQ diagonal matrix for rescaling the output if (tiisg < Q) { - ss[tiisg*T + C + tiisg] = ms[tiisg]; + ss[tiisg*TF + C + tiisg] = ms[tiisg]; } } else { - half ms[Q]; + float ms[Q]; for (short j = 0; j < Q; ++j) { - const half m = M[j]; + const float m = M[j]; for (short p = tiisg; p < C; p += NW) { - const half s = ss[j*T + p]; + const float s = ss[j*TF + p]; smax = max(smax, s); M[j] = max(M[j], s); @@ -2333,20 +2334,20 @@ kernel void kernel_flash_attn_ext_f16( smax = simd_max(smax); M[j] = simd_max(M[j]); - ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); + ms[j] = exp(m - M[j]); // local sum - half ls = 0.0h; + float ls = 0.0h; for (short p = tiisg; p < C; p += NW) { - const half s = ss[j*T + p]; + const float s = ss[j*TF + p]; - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + const float vs = exp(s - M[j]); ls += vs; // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + ss[j*TF + p] = vs; } S[j] = S[j]*ms[j] + simd_sum(ls); @@ -2354,7 +2355,7 @@ kernel void kernel_flash_attn_ext_f16( // create a QxQ diagonal matrix for rescaling the output if (tiisg < Q) { - ss[tiisg*T + C + tiisg] = ms[tiisg]; + ss[tiisg*TF + C + tiisg] = ms[tiisg]; } } @@ -2365,8 +2366,8 @@ kernel void kernel_flash_attn_ext_f16( // O = diag(ms)*O for (short j = 0; j < Q8; ++j) { - simdgroup_half8x8 mm; - simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false); + simdgroup_float8x8 mm; + simdgroup_load(mm, ss + 8*j*TF + C + 8*j, TF, 0, false); for (short i = 0; i < D8; ++i) { simdgroup_multiply(lo[j][i], mm, lo[j][i]); @@ -2383,8 +2384,8 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); for (short j = 0; j < Q8; ++j) { - simdgroup_half8x8 mv; - simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false); + simdgroup_float8x8 mv; + simdgroup_load(mv, ss + 8*j*TF + 8*cc, TF, 0, false); simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]); } @@ -2396,16 +2397,16 @@ kernel void kernel_flash_attn_ext_f16( // these are needed for reducing the results from the simdgroups (reuse the ss buffer) for (short j = 0; j < Q; ++j) { if (tiisg == 0) { - ss[j*T + 0] = S[j]; - ss[j*T + 1] = M[j]; + ss[j*TF + 0] = S[j]; + ss[j*TF + 1] = M[j]; } } } // reduce the warps sequentially for (short sg = 1; sg < nsg; ++sg) { - half S = { 0.0h }; - half M = { -INFINITY }; + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; threadgroup_barrier(mem_flags::mem_threadgroup); @@ -2423,36 +2424,36 @@ kernel void kernel_flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (sgitg == 0) { for (short j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*SH + 0]; + const float S0 = ss[j*TF + 0]; + const float S1 = ss[j*TF + sg*SH + 0]; - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*SH + 1]; + const float M0 = ss[j*TF + 1]; + const float M1 = ss[j*TF + sg*SH + 1]; M = max(M0, M1); - const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M); - const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M); + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); S = S0*ms0 + S1*ms1; if (tiisg == 0) { - ss[j*T + 0] = S; - ss[j*T + 1] = M; + ss[j*TF + 0] = S; + ss[j*TF + 1] = M; - ss[j*T + C + j ] = ms0; - ss[j*T + C + j + sg*SH] = ms1; + ss[j*TF + C + j ] = ms0; + ss[j*TF + C + j + sg*SH] = ms1; } } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 for (short j = 0; j < Q8; ++j) { simdgroup_half8x8 t; - simdgroup_half8x8 ms0; - simdgroup_half8x8 ms1; + simdgroup_float8x8 ms0; + simdgroup_float8x8 ms1; - simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false); - simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false); + simdgroup_load(ms0, ss + 8*j*TF + C + 8*j, TF, 0, false); + simdgroup_load(ms1, ss + 8*j*TF + C + 8*j + sg*SH, TF, 0, false); for (short i = 0; i < D8; ++i) { simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false); @@ -2478,7 +2479,7 @@ kernel void kernel_flash_attn_ext_f16( // final rescale with 1/S and store to global memory if (sgitg == 0) { for (short j = 0; j < Q && iq1 + j < ne01; ++j) { - const half S = ss[j*T + 0]; + const float S = ss[j*TF + 0]; for (short i = tiisg; i < D4; i += NW) { dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; @@ -2494,8 +2495,6 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; -#define HALF_MAX_HALF half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. - template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_vec_f16( device const char * q, @@ -2539,18 +2538,16 @@ kernel void kernel_flash_attn_ext_vec_f16( const short iq1 = tgpig[0]; const short D4 = D/4; - const short D8 = D/8; const short NW = N_SIMDWIDTH; const short SH = (C + 1); // shared memory per simdgroup in (half) - const short T = D + nsg*SH; // shared memory size per query in (half) - const short T4 = T/4; // shared memory size per query in (half4) + const short T = D + 2*nsg*SH; // shared memory size per query in (half) - threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 - threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4 - threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results + //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4 + threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) half4 lo[D4/NW]; @@ -2579,8 +2576,8 @@ kernel void kernel_flash_attn_ext_vec_f16( threadgroup_barrier(mem_flags::mem_threadgroup); { - half S = { 0.0h }; - half M = { -HALF_MAX_HALF }; + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; // assume K and V are same shape const short ne22 = ne12; @@ -2628,7 +2625,7 @@ kernel void kernel_flash_attn_ext_vec_f16( { #pragma unroll for (short cc = 0; cc < C/4; ++cc) { - half4 mqk = { 0.0h }; + float4 mqk = { 0.0h }; device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); @@ -2642,7 +2639,7 @@ kernel void kernel_flash_attn_ext_vec_f16( mk[2] = pk4[i + 2*(nb11/8)]; mk[3] = pk4[i + 3*(nb11/8)]; - mqk += mq[i] * mk; + mqk += (float4) (mq[i] * mk); } // reduce the results from the threads in the simdgroup @@ -2654,7 +2651,7 @@ kernel void kernel_flash_attn_ext_vec_f16( // mqk = mqk*scale + mask if (tiisg == 0) { - half4 mm = mp4[ic/4 + cc]; + float4 mm = (float4) mp4[ic/4 + cc]; mqk = mqk*scale + mm; ss4[cc] = mqk; @@ -2666,13 +2663,13 @@ kernel void kernel_flash_attn_ext_vec_f16( { const short p = tiisg; - const half m = M; - const half s = ss[p]; + const float m = M; + const float s = ss[p]; M = simd_max(max(M, s)); - const half ms = exp(m - M); - const half vs = exp(s - M); + const float ms = exp(m - M); + const float vs = exp(s - M); S = S*ms + simd_sum(vs); @@ -2696,6 +2693,7 @@ kernel void kernel_flash_attn_ext_vec_f16( #pragma unroll for (short ii = 0; ii < D4; ii += NW) { const short i = ii + tiisg; + lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; @@ -2724,18 +2722,18 @@ kernel void kernel_flash_attn_ext_vec_f16( // parallel reduce for (short r = nsg/2; r > 0; r >>= 1) { if (sgitg < r) { - const half S0 = ss[ 0]; - const half S1 = ss[r*SH + 0]; + const float S0 = ss[ 0]; + const float S1 = ss[r*SH + 0]; - const half M0 = ss[ 1]; - const half M1 = ss[r*SH + 1]; + const float M0 = ss[ 1]; + const float M1 = ss[r*SH + 1]; - const half M = max(M0, M1); + const float M = max(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); - const half S = S0*ms0 + S1*ms1; + const float S = S0*ms0 + S1*ms1; if (tiisg == 0) { ss[0] = S; @@ -2756,7 +2754,7 @@ kernel void kernel_flash_attn_ext_vec_f16( // final rescale with 1/S and store to global memory if (sgitg == 0) { - const half S = ss[0]; + const float S = ss[0]; for (short ii = 0; ii < D4; ii += NW) { short i = ii + tiisg; diff --git a/ggml.c b/ggml.c index f50cb948daeab..f2bbfa6f2273b 100644 --- a/ggml.c +++ b/ggml.c @@ -14882,12 +14882,13 @@ static void ggml_compute_forward_flash_attn_ext( struct ggml_tensor * dst) { switch (dst->op_params[1]) { case GGML_PREC_DEFAULT: + case GGML_PREC_F32: { + // uses F32 accumulators ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); } break; default: { - // TODO: implement F32 precision GGML_ASSERT(false); } break; } From 9ca869876eddfb0c51d3b80440a11abdd6a3fe18 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Apr 2024 21:41:32 +0300 Subject: [PATCH 133/180] batched-bench : add fattn arg --- examples/batched-bench/batched-bench.cpp | 28 ++++++++++++++---------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 1e34de620a41b..2924d8116f44f 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -32,7 +32,7 @@ int main(int argc, char ** argv) { gpt_params params; if (argc == 1 || argv[1][0] == '-') { - printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [IS_PP_SHARED] [NGL] \n" , argv[0]); + printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [FATTN] [IS_PP_SHARED] [NGL] \n" , argv[0]); printf(" , and PL are comma-separated lists of numbers without spaces\n\n"); printf(" example: %s ggml-model-f16.gguf 2048 2048 512 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]); return 1 ; @@ -41,6 +41,7 @@ int main(int argc, char ** argv) { int n_kv_max = 2048; int n_batch = 2048; int n_ubatch = 512; + bool flash_attn = false; int is_pp_shared = 0; int n_gpu_layers = 0; @@ -66,23 +67,27 @@ int main(int argc, char ** argv) { } if (argc >= 6) { - is_pp_shared = std::atoi(argv[5]); + flash_attn = std::atoi(argv[5]); } if (argc >= 7) { - n_gpu_layers = std::atoi(argv[6]); + is_pp_shared = std::atoi(argv[6]); } if (argc >= 8) { - n_pp = parse_list(argv[7]); + n_gpu_layers = std::atoi(argv[7]); } if (argc >= 9) { - n_tg = parse_list(argv[8]); + n_pp = parse_list(argv[8]); } if (argc >= 10) { - n_pl = parse_list(argv[9]); + n_tg = parse_list(argv[9]); + } + + if (argc >= 11) { + n_pl = parse_list(argv[10]); } // init LLM @@ -108,10 +113,11 @@ int main(int argc, char ** argv) { llama_context_params ctx_params = llama_context_default_params(); - ctx_params.seed = 1234; - ctx_params.n_ctx = n_kv_max; - ctx_params.n_batch = n_batch; - ctx_params.n_ubatch = n_ubatch; + ctx_params.seed = 1234; + ctx_params.n_ctx = n_kv_max; + ctx_params.n_batch = n_batch; + ctx_params.n_ubatch = n_ubatch; + ctx_params.flash_attn = flash_attn; ctx_params.n_threads = params.n_threads; ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; @@ -169,7 +175,7 @@ int main(int argc, char ** argv) { } LOG_TEE("\n"); - LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch); + LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, flash_attn, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch); LOG_TEE("\n"); LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s"); From 74d57f95136c8391756c8144ad12a517901bd2e2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Apr 2024 13:49:57 +0300 Subject: [PATCH 134/180] llama : simplify llama_build_kv_store ggml-ci --- llama.cpp | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/llama.cpp b/llama.cpp index 4034d25181aaa..d828dd786e679 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5963,29 +5963,27 @@ static void llm_build_kv_store( (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); cb(k_cache_view, "k_cache_view", il); - // important: storing RoPE-ed version of K in the KV cache! + // note: storing RoPE-ed version of K in the KV cache ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); - if (cparams.flash_attn) { - // NOTE: the V cache is not transposed when using FLASH attention !! - struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, - (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head); - cb(v_cache_view, "v_cache_view", il); + assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); - ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); - } else { - // compute the transposed [n_tokens, n_embd] V matrix - //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); - assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); - struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); - cb(v_cur_t, "v_cur_t", il); + struct ggml_tensor * v_cache_view = nullptr; - struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, + if (cparams.flash_attn) { + v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, + (kv_head)*ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa)); + } else { + // note: the V cache is transposed when not using flash attention + v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, ( n_ctx)*ggml_element_size(kv.v_l[il]), (kv_head)*ggml_element_size(kv.v_l[il])); - ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); + v_cur = ggml_transpose(ctx, v_cur); } + cb(v_cache_view, "v_cache_view", il); + + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); } static struct ggml_tensor * llm_build_norm( @@ -6169,11 +6167,6 @@ static struct ggml_tensor * llm_build_kqv( if (model.arch == LLM_ARCH_PHI2) { ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); } - //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); - //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); - //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); - //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); - //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]); cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); } else { @@ -14879,6 +14872,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); + LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); From e32b281743c98411d1eaf22823d3eea2023c3502 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Apr 2024 14:04:56 +0300 Subject: [PATCH 135/180] llama : adapt build_olmo to changes --- llama.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index add75818b2bdc..a7ce50dd30efa 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10287,9 +10287,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, nullptr, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { From 703c6e6528d184eaf6ea0bed21cd350e095c63f4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Apr 2024 14:20:41 +0300 Subject: [PATCH 136/180] ggml : fix arm fp16 store on windows --- ggml.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 4c749ecdafa5b..76ca79e660aa0 100644 --- a/ggml.c +++ b/ggml.c @@ -963,7 +963,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F16x8_FMA #define GGML_F16_VEC_ADD GGML_F16x8_ADD #define GGML_F16_VEC_MUL GGML_F16x8_MUL @@ -989,7 +989,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL From 97eaece7d6537b97a01d30a71a7ce4511fa97e25 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Apr 2024 15:30:27 +0300 Subject: [PATCH 137/180] metal : clean-up --- ggml-metal.m | 353 ++++++++++++++++++++++++++------------------------- 1 file changed, 178 insertions(+), 175 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 44a3d05f354e5..68eb49d5b1baa 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -475,175 +475,175 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ // simd_sum and simd_max requires MTLGPUFamilyApple7 - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } [metal_library release]; @@ -2560,7 +2560,9 @@ static enum ggml_status ggml_metal_graph_compute( id pipeline = nil; - if (ne01 > 1 || (ne00%128 != 0)) { + bool use_vec_kernel = false; + + if (ne01 >= 4 || (ne00%128 != 0)) { switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; @@ -2576,6 +2578,8 @@ static enum ggml_status ggml_metal_graph_compute( } } } else { + use_vec_kernel = true; + switch (ne00) { case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; @@ -2588,7 +2592,6 @@ static enum ggml_status ggml_metal_graph_compute( } } - // TODO: extend if necessary [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -2619,8 +2622,8 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - // half8x8 kernel - if (ne01 > 1 || (ne00%128 != 0)) { + if (!use_vec_kernel) { + // half8x8 kernel const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! @@ -2635,7 +2638,7 @@ static enum ggml_status ggml_metal_graph_compute( //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:smem atIndex:0]; + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } else { @@ -2926,7 +2929,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff return NULL; } - ggml_backend_metal_log_allocated_size(device, size_aligned); + //ggml_backend_metal_log_allocated_size(device, size_aligned); return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size); } From 1a88565b4489381923aa0c9a6741badfb6766b23 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Apr 2024 15:52:49 +0300 Subject: [PATCH 138/180] metal : clean-up kernel code --- ggml-metal.metal | 142 ++++++++++++++--------------------------------- 1 file changed, 43 insertions(+), 99 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 1ed5632b4e0ec..32cbef9dca103 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2121,7 +2121,7 @@ typedef void (flash_attn_ext_f16_t)( ushort sgitg[[simdgroup_index_in_threadgroup]]); // ref: https://arxiv.org/pdf/2307.08691.pdf -template // head size, queries per threadgroup, cache items per threadgroup +template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2178,7 +2178,7 @@ kernel void kernel_flash_attn_ext_f16( threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - simdgroup_half8x8 lo[Q8][D8]; + simdgroup_half8x8 lo[D8]; // load heads from Q to shared memory for (short j = sgitg; j < Q; j += nsg) { @@ -2194,10 +2194,8 @@ kernel void kernel_flash_attn_ext_f16( } // zero out lo - for (short j = 0; j < Q8; ++j) { - for (short i = 0; i < D8; ++i) { - lo[j][i] = make_filled_simdgroup_matrix(0.0h); - } + for (short i = 0; i < D8; ++i) { + lo[i] = make_filled_simdgroup_matrix(0.0h); } // zero out shared memory SH @@ -2229,20 +2227,18 @@ kernel void kernel_flash_attn_ext_f16( const short rv3 = ne03/ne23; // k indices - const short ik2 = iq2 / rk2; - const short ik3 = iq3 / rk3; + const short ik2 = iq2/rk2; + const short ik3 = iq3/rk3; // v indices - const short iv2 = iq2 / rv2; - const short iv3 = iq3 / rv3; + const short iv2 = iq2/rv2; + const short iv3 = iq3/rv3; // load the queries from shared memory into local memory - simdgroup_half8x8 mq[Q8][D8]; + simdgroup_half8x8 mq[D8]; - for (short j = 0; j < Q8; ++j) { - for (short i = 0; i < D8; ++i) { - simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); - } + for (short i = 0; i < D8; ++i) { + simdgroup_load(mq[i], sq + i*8, T); } // pointer to the mask @@ -2262,10 +2258,7 @@ kernel void kernel_flash_attn_ext_f16( // Q*K^T { for (short cc = 0; cc < C/8; ++cc) { - simdgroup_float8x8 mqk[Q8]; - for (short j = 0; j < Q8; ++j) { - mqk[j] = make_filled_simdgroup_matrix(0.h); - } + simdgroup_float8x8 mqk = make_filled_simdgroup_matrix(0.h); device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); @@ -2273,19 +2266,15 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_half8x8 mk; simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose - for (short j = 0; j < Q8; ++j) { - simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); - } + simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } // mqk = mqk*scale + mask - for (short j = 0; j < Q8; ++j) { - simdgroup_half8x8 mm; - simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); - simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); + simdgroup_half8x8 mm; + simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false); + simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); - simdgroup_store(mqk[j], ss + 8*j*TF + 8*cc, TF, 0, false); - } + simdgroup_store(mqk, ss + 8*cc, TF, 0, false); } } @@ -2293,7 +2282,7 @@ kernel void kernel_flash_attn_ext_f16( float smax = -INFINITY; // online softmax - if (C == 32) { + { float ms[Q]; for (short j = 0; j < Q; ++j) { @@ -2314,45 +2303,6 @@ kernel void kernel_flash_attn_ext_f16( ss[j*TF + p] = vs; } - // create a QxQ diagonal matrix for rescaling the output - if (tiisg < Q) { - ss[tiisg*TF + C + tiisg] = ms[tiisg]; - } - } else { - float ms[Q]; - - for (short j = 0; j < Q; ++j) { - const float m = M[j]; - - for (short p = tiisg; p < C; p += NW) { - const float s = ss[j*TF + p]; - - smax = max(smax, s); - M[j] = max(M[j], s); - } - - smax = simd_max(smax); - M[j] = simd_max(M[j]); - - ms[j] = exp(m - M[j]); - - // local sum - float ls = 0.0h; - - for (short p = tiisg; p < C; p += NW) { - const float s = ss[j*TF + p]; - - const float vs = exp(s - M[j]); - - ls += vs; - - // the P matrix from the paper (Q rows, C columns) - ss[j*TF + p] = vs; - } - - S[j] = S[j]*ms[j] + simd_sum(ls); - } - // create a QxQ diagonal matrix for rescaling the output if (tiisg < Q) { ss[tiisg*TF + C + tiisg] = ms[tiisg]; @@ -2365,12 +2315,12 @@ kernel void kernel_flash_attn_ext_f16( } // O = diag(ms)*O - for (short j = 0; j < Q8; ++j) { + { simdgroup_float8x8 mm; - simdgroup_load(mm, ss + 8*j*TF + C + 8*j, TF, 0, false); + simdgroup_load(mm, ss + C, TF, 0, false); for (short i = 0; i < D8; ++i) { - simdgroup_multiply(lo[j][i], mm, lo[j][i]); + simdgroup_multiply(lo[i], mm, lo[i]); } } @@ -2383,12 +2333,10 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_half8x8 mk; simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); - for (short j = 0; j < Q8; ++j) { - simdgroup_float8x8 mv; - simdgroup_load(mv, ss + 8*j*TF + 8*cc, TF, 0, false); + simdgroup_float8x8 mv; + simdgroup_load(mv, ss + 8*cc, TF, 0, false); - simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]); - } + simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]); } } } @@ -2412,10 +2360,8 @@ kernel void kernel_flash_attn_ext_f16( // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { - for (short j = 0; j < Q8; ++j) { - for (short i = 0; i < D8; ++i) { - simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); - } + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); } } @@ -2447,19 +2393,19 @@ kernel void kernel_flash_attn_ext_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (short j = 0; j < Q8; ++j) { + { simdgroup_half8x8 t; simdgroup_float8x8 ms0; simdgroup_float8x8 ms1; - simdgroup_load(ms0, ss + 8*j*TF + C + 8*j, TF, 0, false); - simdgroup_load(ms1, ss + 8*j*TF + C + 8*j + sg*SH, TF, 0, false); + simdgroup_load(ms0, ss + C, TF, 0, false); + simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false); for (short i = 0; i < D8; ++i) { - simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false); + simdgroup_load (t, sq + i*8, T, 0, false); simdgroup_multiply(t, ms1, t); - simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t); + simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); } } } @@ -2467,10 +2413,8 @@ kernel void kernel_flash_attn_ext_f16( // store result to shared memory (reuse sq) if (sgitg == 0) { - for (short j = 0; j < Q8; ++j) { - for (short i = 0; i < D8; ++i) { - simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); - } + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); } } @@ -2488,14 +2432,14 @@ kernel void kernel_flash_attn_ext_f16( } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>; -template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96, 8, 32>; -template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112, 8, 32>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; -template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; -template // head size, queries per threadgroup, cache items per threadgroup +template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_vec_f16( device const char * q, device const char * k, @@ -2539,7 +2483,7 @@ kernel void kernel_flash_attn_ext_vec_f16( const short D4 = D/4; const short NW = N_SIMDWIDTH; - const short SH = (C + 1); // shared memory per simdgroup in (half) + const short SH = (C + Q); // shared memory per simdgroup in (half) const short T = D + 2*nsg*SH; // shared memory size per query in (half) @@ -2763,8 +2707,8 @@ kernel void kernel_flash_attn_ext_vec_f16( } } -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 32>; -template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 32>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; kernel void kernel_cpy_f16_f16( device const half * src0, From bc346166f96b13cf62291fe8a0212c98e561645c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Apr 2024 17:24:52 +0300 Subject: [PATCH 139/180] metal : minor --- ggml-metal.m | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 68eb49d5b1baa..aa22a24f01c38 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -451,7 +451,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ } /* - GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ (int) kernel->pipeline.threadExecutionWidth); \ */ @@ -461,7 +461,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ id metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \ kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \ [metal_function release]; \ - GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ (int) kernel->pipeline.threadExecutionWidth); \ if (error) { \ @@ -470,7 +470,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ return NULL; \ } \ } else { \ - GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \ + GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \ } // simd_sum and simd_max requires MTLGPUFamilyApple7 From 52945429eb43be1c8b89a77719513cd088e29586 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Apr 2024 17:38:28 +0300 Subject: [PATCH 140/180] tests : remove benchmarks ggml-ci --- tests/test-backend-ops.cpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 1cc5e53bd2442..2317b8b7e1fab 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -15,6 +15,9 @@ #include #include +// TODO: remove before merging +//#define TMP_ATTN_BENCH + static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { // static RNG initialization (revisit if n_threads stops being constant) static const size_t n_threads = std::thread::hardware_concurrency(); @@ -571,7 +574,7 @@ struct test_case { // duplicate the op size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1; -#if 0 +#ifndef TMP_ATTN_BENCH for (int i = 1; i < n_runs; i++) { gf->nodes[gf->n_nodes++] = out; } @@ -1513,8 +1516,8 @@ struct test_flash_attn_ext : public test_case { } }; +#ifdef TMP_ATTN_BENCH // ATTN -// TODO: this is temporary until the FA branch is merged struct test_attn : public test_case { const int64_t hs; // head size const int64_t nh; // num heads @@ -1555,6 +1558,7 @@ struct test_attn : public test_case { return cur; } }; +#endif enum llm_norm_type { LLM_NORM, @@ -2220,7 +2224,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_timestep_embedding()); test_cases.emplace_back(new test_leaky_relu()); -#if 1 +#ifdef TMP_ATTN_BENCH for (int hs : { 128, 256, 64, 80, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { @@ -2232,11 +2236,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } #else - for (int hs : { 128, }) { + for (int hs : { 64, 80, 128, 256, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, }) { - for (int nb : { 1, 2, 4, 8, 512 }) { - test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); + for (int nb : { 1, 2, 4, 8, }) { test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); } } From 3badef1fe143764c1298a45cd0986a737eba0a8d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Apr 2024 17:45:08 +0300 Subject: [PATCH 141/180] ggml : fix avx512 const correctness ggml-ci --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 76ca79e660aa0..1d88e0da246e9 100644 --- a/ggml.c +++ b/ggml.c @@ -1058,7 +1058,7 @@ do { \ // unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F // so F16C guard isn't required -#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x))) +#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0)) #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a) From 871fcb6e101dc3fdc92fce273a7c932d16b72d8a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Apr 2024 18:03:56 +0300 Subject: [PATCH 142/180] ggml : fix soft_max with bias on CPU ggml-ci --- ggml.c | 4 ++-- tests/test-backend-ops.cpp | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/ggml.c b/ggml.c index 1d88e0da246e9..41557ab6766c5 100644 --- a/ggml.c +++ b/ggml.c @@ -12410,7 +12410,7 @@ static void ggml_compute_forward_soft_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching - float * pos = src2 ? (float *) src2->data : src0->data; + ggml_fp16_t * pos = src2 ? (ggml_fp16_t *) src2->data : src0->data; for (int i1 = ir0; i1 < ir1; i1++) { float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); @@ -12433,7 +12433,7 @@ static void ggml_compute_forward_soft_max_f32( const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1); for (int i = 0; i < nc; i++) { - wp[i] = wp[i] + slope*pos[i]; + wp[i] = wp[i] + slope*ggml_fp16_to_fp32(pos[i]); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2317b8b7e1fab..ce39dadbb61e3 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1103,6 +1103,12 @@ struct test_soft_max : public test_case { return VARS_TO_STR5(type, ne, mask, scale, max_bias); } + // the 1024 test with bias occasionally fails: + // SOFT_MAX(type=f32,ne=[1024,16,1,1],mask=1,scale=1.000000,max_bias=8.000000): [SOFT_MAX] NMSE = 0.000000103 > 0.000000100 FAIL + virtual double max_nmse_err() override { + return 1e-6; + } + test_soft_max(ggml_type type = GGML_TYPE_F32, std::array ne = {10, 10, 10, 10}, bool mask = false, @@ -2180,7 +2186,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (float scale : {1.0f, 0.1f}) { for (int64_t ne0 : {16, 1024}) { for (int64_t ne1 : {16, 1024}) { - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias)); } } From a39217d4285b44c1b916c949ef6581e82f3c3ef3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 22 Apr 2024 12:50:10 +0300 Subject: [PATCH 143/180] common : print --flash-attn in help --- common/common.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/common/common.cpp b/common/common.cpp index fbff8cf13effc..a29c451aa94fc 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1482,6 +1482,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); + printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled"); printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n"); printf(" --image IMAGE_FILE path to an image file. use with multimodal models\n"); if (llama_supports_mlock()) { From cb76d747d166dd9bbd028d666b1e8e53fe10efa0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 22 Apr 2024 12:50:26 +0300 Subject: [PATCH 144/180] ggml : fix num dimensions in ggml_flash_attn_ext --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 41557ab6766c5..b1c76e6789749 100644 --- a/ggml.c +++ b/ggml.c @@ -6321,7 +6321,7 @@ struct ggml_tensor * ggml_flash_attn_ext( // permute(0, 2, 1, 3) int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); float params[] = { scale }; ggml_set_op_params(result, params, sizeof(params)); From c11d05fec03a9190cabe8f1afb58a381811d5e21 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 22 Apr 2024 12:50:41 +0300 Subject: [PATCH 145/180] llama : force disable flash attention for incompatible models --- llama.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index a7ce50dd30efa..a4b00e7ff3ddf 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1823,7 +1823,7 @@ struct llama_hparams { float f_logit_scale = 0.0f; bool causal_attn = true; - bool need_kq_pos = false; + bool need_kq_pos = false; // currently, we need KQ_pos data for ALiBi-based models enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; @@ -6311,6 +6311,8 @@ static struct ggml_tensor * llm_build_kqv( GGML_UNUSED(model); GGML_UNUSED(n_ctx); + // note: if this assert triggers, then some check has failed earlier + // the idea is to detect during context creation that ALiBi would be used and disable Flash Attention GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); // split cached v into n_head heads (not transposed) @@ -15114,6 +15116,16 @@ struct llama_context * llama_new_context_with_model( } } + if (cparams.flash_attn && hparams.need_kq_pos) { + LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with ALiBi - forcing off\n", __func__); + cparams.flash_attn = false; + } + + if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) { + LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__); + cparams.flash_attn = false; + } + if (params.seed == LLAMA_DEFAULT_SEED) { params.seed = time(NULL); } From f725ca90fb77f32e52ee2c204708560c952fdf78 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 22 Apr 2024 13:46:23 +0300 Subject: [PATCH 146/180] ggml : ggml_soft_max support F16/F32 mask/pos ggml-ci --- ggml-cuda/softmax.cu | 46 +++++++++++++++++++++++++++++--------- ggml-metal.m | 29 ++++++++++++++++++------ ggml-metal.metal | 18 +++++++++++---- ggml.c | 38 +++++++++++++++++++++++-------- llama.cpp | 4 ++-- tests/test-backend-ops.cpp | 4 ++-- 6 files changed, 105 insertions(+), 34 deletions(-) diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu index 8f6dca4d0f9bf..c0557db78df8d 100644 --- a/ggml-cuda/softmax.cu +++ b/ggml-cuda/softmax.cu @@ -1,7 +1,17 @@ #include "softmax.cuh" -template -static __global__ void soft_max_f32(const float * x, const half * mask, const half * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { +template +static __device__ __forceinline__ float t2f32(T val) { + return (float) val; +} + +template <> +__device__ float __forceinline__ t2f32(half val) { + return __half2float(val); +} + +template +static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -43,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const half * mask, const ha const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f); + const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -114,7 +124,8 @@ static __global__ void soft_max_f32(const float * x, const half * mask, const ha } } -static void soft_max_f32_cuda(const float * x, const half * mask, const half * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { +template +static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -167,15 +178,19 @@ static void soft_max_f32_cuda(const float * x, const half * mask, const half * p void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + const float * src0_d = (const float *)src0->data; - const half * src1_d = src1 ? (const half *)src1->data : nullptr; + const void * src1_d = src1 ? (const void *)src1->data : nullptr; + float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -188,14 +203,25 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); // positions tensor - half * src2_dd = nullptr; + void * src2_d = nullptr; - ggml_tensor * src2 = dst->src[2]; const bool use_src2 = src2 != nullptr; if (use_src2) { - src2_dd = (half *)src2->data; + src2_d = (void *)src2->data; } - soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + + if (use_f16) { + const half * src1_dd = (const half *)src1_d; + const half * src2_dd = (const half *)src2_d; + + soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + } else { + const float * src1_dd = (const float *)src1_d; + const float * src2_dd = (const float *)src2_d; + + soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + } } diff --git a/ggml-metal.m b/ggml-metal.m index aa22a24f01c38..1903791f1f9e3 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -46,8 +46,10 @@ GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, GGML_METAL_KERNEL_TYPE_SILU, GGML_METAL_KERNEL_TYPE_SILU_4, - GGML_METAL_KERNEL_TYPE_SOFT_MAX, - GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, @@ -492,8 +494,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); @@ -1346,22 +1350,33 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_SOFT_MAX: { - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); int nth = 32; // SIMD width id pipeline = nil; + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + if (ne00%4 == 0) { while (nth < ne00/4 && nth < 256) { nth *= 2; } - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline; + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; + } } else { while (nth < ne00 && nth < 1024) { nth *= 2; } - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline; + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; + } } float scale; diff --git a/ggml-metal.metal b/ggml-metal.metal index 32cbef9dca103..3d4276ae02b9e 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -352,6 +352,7 @@ kernel void kernel_sum_rows( dst_row[0] = row_sum; } +template kernel void kernel_soft_max( device const char * src0, device const char * src1, @@ -376,8 +377,8 @@ kernel void kernel_soft_max( const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const half * pmask = src1 != src0 ? (device const half *) src1 + i01*ne00 : nullptr; - device const half * ppos = src2 != src0 ? (device const half *) src2 : nullptr; + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr; + device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr; device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); float slope = 0.0f; @@ -456,6 +457,7 @@ kernel void kernel_soft_max( } } +template kernel void kernel_soft_max_4( device const char * src0, device const char * src1, @@ -480,8 +482,8 @@ kernel void kernel_soft_max_4( const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; - device const half4 * pmask = src1 != src0 ? (device const half4 *) src1 + i01*ne00/4 : nullptr; - device const half4 * ppos = src2 != src0 ? (device const half4 *) src2 : nullptr; + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; + device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr; device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; float slope = 0.0f; @@ -562,6 +564,14 @@ kernel void kernel_soft_max_4( } } +typedef decltype(kernel_soft_max) kernel_soft_max_t; +typedef decltype(kernel_soft_max_4) kernel_soft_max_4_t; + +template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; +template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; + kernel void kernel_diag_mask_inf( device const float * src0, device float * dst, diff --git a/ggml.c b/ggml.c index b1c76e6789749..bc19f35bff84f 100644 --- a/ggml.c +++ b/ggml.c @@ -5473,7 +5473,7 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(ggml_is_contiguous(a)); if (mask) { - GGML_ASSERT(mask->type == GGML_TYPE_F16); + GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_matrix(mask)); GGML_ASSERT(mask->ne[1] >= a->ne[1]); @@ -5481,10 +5481,14 @@ static struct ggml_tensor * ggml_soft_max_impl( if (pos) { GGML_ASSERT(ggml_is_vector(pos)); - GGML_ASSERT(pos->type == GGML_TYPE_F16); + GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32); GGML_ASSERT(pos->ne[0] == a->ne[0]); } + if (pos && mask) { + GGML_ASSERT(pos->type == mask->type); + } + if (max_bias > 0.0f) { GGML_ASSERT(pos); } @@ -12410,20 +12414,30 @@ static void ggml_compute_forward_soft_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching - ggml_fp16_t * pos = src2 ? (ggml_fp16_t *) src2->data : src0->data; + ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data; + float * pos_f32 = src2 ? (float *) src2->data : src0->data; + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); for (int i1 = ir0; i1 < ir1; i1++) { float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); // broadcast the mask across rows - ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; ggml_vec_cpy_f32 (nc, wp, sp); ggml_vec_scale_f32(nc, wp, scale); - if (mp) { - for (int i = 0; i < nc; ++i) { - wp[i] += GGML_FP16_TO_FP32(mp[i]); + if (mp_f32) { + if (use_f16) { + for (int i = 0; i < nc; ++i) { + wp[i] += GGML_FP16_TO_FP32(mp_f16[i]); + } + } else { + for (int i = 0; i < nc; ++i) { + wp[i] += mp_f32[i]; + } } } @@ -12432,8 +12446,14 @@ static void ggml_compute_forward_soft_max_f32( const uint32_t h = (i1/ne01)%ne02; // head const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1); - for (int i = 0; i < nc; i++) { - wp[i] = wp[i] + slope*ggml_fp16_to_fp32(pos[i]); + if (use_f16) { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]); + } + } else { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*pos_f32[i]; + } } } diff --git a/llama.cpp b/llama.cpp index a4b00e7ff3ddf..26802d96a6752 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6710,14 +6710,14 @@ struct llm_build_context { } cb(lctx.inp_KQ_mask, "KQ_mask", -1); ggml_set_input(lctx.inp_KQ_mask); - return ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16); + return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; } struct ggml_tensor * build_inp_KQ_pos() { lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); cb(lctx.inp_KQ_pos, "KQ_pos", -1); ggml_set_input(lctx.inp_KQ_pos); - return ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16); + return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16) : lctx.inp_KQ_pos; } struct ggml_tensor * build_inp_mean() { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index ce39dadbb61e3..d044a6ea02481 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1120,11 +1120,11 @@ struct test_soft_max : public test_case { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_tensor * mask = nullptr; if (this->mask) { - mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, ne[0], ne[1]); + mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]); } ggml_tensor * pos = nullptr; if (max_bias > 0.0f) { - pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, ne[0]); + pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]); } ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias); return out; From 5408d55506186b8e56f8a6a748688847ef0ebb7d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 22 Apr 2024 19:12:06 +0300 Subject: [PATCH 147/180] cuda : uint -> uint32_t --- ggml-cuda/common.cuh | 6 +++--- ggml-cuda/fattn.cu | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index ac6de643d668e..e82d63e4a0abc 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -307,9 +307,9 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { } #if CUDART_VERSION < 12000 -static __device__ __forceinline__ uint __hgt2_mask(const half2 a, const half2 b) { - const uint mask_low = 0x0000FFFF * ( __low2half(a) > __low2half(b)); - const uint mask_high = 0xFFFF0000 * (__high2half(a) > __high2half(b)); +static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) { + const uint32_t mask_low = 0x0000FFFF * ( __low2half(a) > __low2half(b)); + const uint32_t mask_high = 0xFFFF0000 * (__high2half(a) > __high2half(b)); return mask_low | mask_high; } #endif // CUDART_VERSION < 12000 diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 4cf2907e8d10c..2077da53dc68f 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -418,8 +418,8 @@ static __global__ void flash_attn_ext_f16( KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; KQ_max_scale_h2[j0/nwarps] = h2exp(diff); - const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; KQ_max_h2[j0/nwarps] = KQ_max_new; half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); @@ -429,8 +429,8 @@ static __global__ void flash_attn_ext_f16( const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); - const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; } @@ -602,8 +602,8 @@ static __global__ void flash_attn_combine_results( for (int l = 0; l < parallel_blocks; ++l) { const float diff = meta[l].x - kqmax; const float KQ_max_scale = expf(diff); - const uint ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); - *((uint *) &KQ_max_scale) &= ftz_mask; + const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); + *((uint32_t *) &KQ_max_scale) &= ftz_mask; VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; VKQ_denominator += KQ_max_scale * meta[l].y; From c70bfd7bcb5b218bea00cddee0dfca0a7d4e4c7f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 22 Apr 2024 20:31:23 +0300 Subject: [PATCH 148/180] cuda : "constexpr dim3" -> "const dim3" ggml-ci --- ggml-cuda/fattn.cu | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 2077da53dc68f..aaaea2f0701d8 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -652,7 +652,7 @@ template void launch_fattn_vec_f16( } constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; - constexpr dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 block_dim(WARP_SIZE, nwarps, 1); const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); const int shmem = 0; @@ -680,9 +680,9 @@ template void launch_fattn_vec_f16( return; } - constexpr dim3 block_dim_combine(D, 1, 1); - const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); - const int shmem_combine = 0; + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; flash_attn_combine_results <<>> @@ -703,7 +703,7 @@ template ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); const int shmem = 0; @@ -731,9 +731,9 @@ template ne[1], blocks_num.y, blocks_num.z); - const int shmem_combine = 0; + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; flash_attn_combine_results <<>> From bf59e42cb4862a03dde18d4f5d712ba3719a02f4 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 23 Apr 2024 11:44:58 +0900 Subject: [PATCH 149/180] added implementation of DRY sampler --- common/sampling.cpp | 17 ++++++++++++- common/sampling.h | 6 ++++- llama.cpp | 58 +++++++++++++++++++++++++++++++++++++++++++++ llama.h | 12 ++++++++++ 4 files changed, 91 insertions(+), 2 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 45d68b26c2b93..d2cbeb8cea25f 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -256,13 +256,18 @@ static llama_token_data_array llama_sampling_prepare_impl( const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); + // repetition penalties const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; const float penalty_repeat = params.penalty_repeat; const float penalty_freq = params.penalty_freq; const float penalty_present = params.penalty_present; - const bool penalize_nl = params.penalize_nl; + // DRY sampler parameters + const float dry_multiplier = params.dry_multiplier; + const float dry_base = params.dry_base; + const int dry_allowed_length = params.dry_allowed_length; + auto & prev = ctx_sampling->prev; auto & cur = ctx_sampling->cur; @@ -298,10 +303,20 @@ static llama_token_data_array llama_sampling_prepare_impl( if (penalty_tokens_used_size) { const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; + // repetition penalties llama_sample_repetition_penalties(ctx_main, &cur_p, penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); + // DRY penalties (multiplier > 0 means enabled) + if(dry_multiplier > 0.0f) { + llama_sample_dry(ctx_main, &cur_p, + penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, + penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length, + params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size()); + } + + if (!penalize_nl) { for (size_t idx = 0; idx < cur_p.size; idx++) { if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) { diff --git a/common/sampling.h b/common/sampling.h index 639b819ab4fb2..644830348b42a 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -38,7 +38,10 @@ typedef struct llama_sampling_params { int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate - bool penalize_nl = false; // consider newlines as a repeatable token + bool penalize_nl = false; // consider newlines as a repeatable token + float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f + float dry_base = 1.75f; + int dry_allowed_length = 2; std::vector samplers_sequence = { llama_sampler_type::TOP_K, @@ -59,6 +62,7 @@ typedef struct llama_sampling_params { std::unordered_map logit_bias; // logit bias for specific tokens std::vector penalty_prompt_tokens; + std::vector dry_sequence_breakers; // sequence breakers for the DRY sampler bool use_penalty_prompt_tokens = false; } llama_sampling_params; diff --git a/llama.cpp b/llama.cpp index a25d115c1d82a..fa08cd71c0856 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13044,6 +13044,64 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can } } +void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_token_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) { + // loop through each candidate + for (size_t i = 0; i < candidates->size; ++i) { + + // if our candidate itself is part of the sequence breakers, we don't apply the dry penalty + if (std::find(seq_breakers, seq_breakers + seq_breakers_size, candidates->data[i].id) != seq_breakers + seq_breakers_size) { + continue; + } + + int max_match_length = 0; + + // loop through each previous token + for (size_t j = 0; j < last_token_size; ++j) { + // if the current candidate is the same as the previous token + if (candidates->data[i].id == last_tokens[j]) { + // greedily match sequence backwards starting from the current position with the end of prev + int match_length = 1; + + // loop through the previous tokens + for(;; match_length++) { + // if we have reached the start of our stored prev, break + if(j - match_length > 0) break; + + // this shouldn't happen because (j - match_length) should always be smaller than (size - match_length) + // but let's check here to avoid the unexpected + if(last_token_size - match_length < 0) break; + + // compare token starts at our prev index, going backwards by match length + auto compare_token = last_tokens[j - match_length]; + + // head token starts at the end of prev, going backwards by match length + auto head_token = last_tokens[last_token_size - match_length]; + + // if compare token is part of the sequence breakers, break out of the match + if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size) + break; + + // break out of the match if any tokens don't match + if(compare_token != head_token) + break; + } + + // update our max match length + max_match_length = std::max(max_match_length, match_length); + } + } + + // apply penalties + if(max_match_length > dry_allowed_length) { + // calculate the penalty + float penalty = dry_multiplier * pow(dry_base, max_match_length - dry_allowed_length); + + // apply the dry penalty + candidates->data[i].logit -= penalty; + } + } +} + void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { if (z >= 1.0f || candidates->size <= 2) { return; diff --git a/llama.h b/llama.h index 4effca42cc65d..050df6353677e 100644 --- a/llama.h +++ b/llama.h @@ -922,6 +922,18 @@ extern "C" { float p, size_t min_keep); + /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677 + LLAMA_API void llama_sample_dry( + struct llama_context * ctx, + llama_token_data_array * candidates, + const llama_token * last_tokens, + int last_token_size, + float dry_base, + float dry_multiplier, + int dry_allowed_length, + const llama_token * seq_breakers, + int seq_breakers_size); + /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. LLAMA_API void llama_sample_tail_free( struct llama_context * ctx, From 11487edd2ae506e50345f3aca79a44dbcd36f22c Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 23 Apr 2024 11:56:47 +0900 Subject: [PATCH 150/180] fixed size_t/int comparison error --- llama.cpp | 8 ++++---- llama.h | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/llama.cpp b/llama.cpp index fa08cd71c0856..ec818cbb606a9 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13044,7 +13044,7 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can } } -void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_token_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) { +void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, size_t seq_breakers_size) { // loop through each candidate for (size_t i = 0; i < candidates->size; ++i) { @@ -13056,7 +13056,7 @@ void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candi int max_match_length = 0; // loop through each previous token - for (size_t j = 0; j < last_token_size; ++j) { + for (size_t j = 0; j < last_tokens_size; ++j) { // if the current candidate is the same as the previous token if (candidates->data[i].id == last_tokens[j]) { // greedily match sequence backwards starting from the current position with the end of prev @@ -13069,13 +13069,13 @@ void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candi // this shouldn't happen because (j - match_length) should always be smaller than (size - match_length) // but let's check here to avoid the unexpected - if(last_token_size - match_length < 0) break; + if(last_tokens_size - match_length < 0) break; // compare token starts at our prev index, going backwards by match length auto compare_token = last_tokens[j - match_length]; // head token starts at the end of prev, going backwards by match length - auto head_token = last_tokens[last_token_size - match_length]; + auto head_token = last_tokens[last_tokens_size - match_length]; // if compare token is part of the sequence breakers, break out of the match if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size) diff --git a/llama.h b/llama.h index 050df6353677e..df5fb20ddc22c 100644 --- a/llama.h +++ b/llama.h @@ -927,12 +927,12 @@ extern "C" { struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, - int last_token_size, + size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, - int seq_breakers_size); + size_t seq_breakers_size); /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. LLAMA_API void llama_sample_tail_free( From f9e416d3b7ecc3da7697df5b3a227b3b25a2bb34 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 23 Apr 2024 12:02:46 +0900 Subject: [PATCH 151/180] removed unused parameters in llama_sample_dry --- common/sampling.cpp | 2 +- llama.cpp | 2 +- llama.h | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index d2cbeb8cea25f..0786c9c1d9a07 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -310,7 +310,7 @@ static llama_token_data_array llama_sampling_prepare_impl( // DRY penalties (multiplier > 0 means enabled) if(dry_multiplier > 0.0f) { - llama_sample_dry(ctx_main, &cur_p, + llama_sample_dry(&cur_p, penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length, params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size()); diff --git a/llama.cpp b/llama.cpp index ec818cbb606a9..7f3a283a6f500 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13044,7 +13044,7 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can } } -void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, size_t seq_breakers_size) { +void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, size_t seq_breakers_size) { // loop through each candidate for (size_t i = 0; i < candidates->size; ++i) { diff --git a/llama.h b/llama.h index df5fb20ddc22c..6cb63ff6cd50a 100644 --- a/llama.h +++ b/llama.h @@ -924,7 +924,6 @@ extern "C" { /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677 LLAMA_API void llama_sample_dry( - struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, From 62b9cb41d5ebec78ae2864c47f025d5c6ef4824e Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 23 Apr 2024 12:03:40 +0900 Subject: [PATCH 152/180] removed trailing whitespace --- common/sampling.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 0786c9c1d9a07..7bb9101a7e065 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -315,7 +315,7 @@ static llama_token_data_array llama_sampling_prepare_impl( penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length, params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size()); } - + if (!penalize_nl) { for (size_t idx = 0; idx < cur_p.size; idx++) { From babff06d7bce0938cc9184afa65d8a65e26b02bf Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 23 Apr 2024 13:28:06 +0900 Subject: [PATCH 153/180] removed unnecessary comparison --- llama.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index 7f3a283a6f500..32c67cb094141 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13067,9 +13067,8 @@ void llama_sample_dry(llama_token_data_array * candidates, const llama_token * l // if we have reached the start of our stored prev, break if(j - match_length > 0) break; - // this shouldn't happen because (j - match_length) should always be smaller than (size - match_length) - // but let's check here to avoid the unexpected - if(last_tokens_size - match_length < 0) break; + // (last_tokens_size - match_length) is unsigned so will always be greater or equal to 0 + // so no need to check for index out of bound here // compare token starts at our prev index, going backwards by match length auto compare_token = last_tokens[j - match_length]; From 99b77600f199c6a4c0ca0d392e5c61bef6073915 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 23 Apr 2024 13:44:22 +0900 Subject: [PATCH 154/180] added dry sampler implementatin --- common/sampling.cpp | 17 ++++++++++++- common/sampling.h | 6 ++++- llama.cpp | 58 +++++++++++++++++++++++++++++++++++++++++++++ llama.h | 12 ++++++++++ 4 files changed, 91 insertions(+), 2 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 5ae941fc73af3..c7a2b7595200b 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -260,13 +260,18 @@ static llama_token_data_array llama_sampling_prepare_impl( const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); + // repetition penalties const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; const float penalty_repeat = params.penalty_repeat; const float penalty_freq = params.penalty_freq; const float penalty_present = params.penalty_present; - const bool penalize_nl = params.penalize_nl; + // DRY sampler parameters + const float dry_multiplier = params.dry_multiplier; + const float dry_base = params.dry_base; + const int dry_allowed_length = params.dry_allowed_length; + auto & prev = ctx_sampling->prev; auto & cur = ctx_sampling->cur; @@ -302,10 +307,20 @@ static llama_token_data_array llama_sampling_prepare_impl( if (penalty_tokens_used_size) { const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; + // repetition penalties llama_sample_repetition_penalties(ctx_main, &cur_p, penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); + // DRY penalties (multiplier > 0 means enabled) + if(dry_multiplier > 0.0f) { + llama_sample_dry(ctx_main, &cur_p, + penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, + penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length, + params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size()); + } + + if (!penalize_nl) { for (size_t idx = 0; idx < cur_p.size; idx++) { if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) { diff --git a/common/sampling.h b/common/sampling.h index f07c673c05160..85bbd8d00b587 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -38,7 +38,10 @@ typedef struct llama_sampling_params { int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate - bool penalize_nl = false; // consider newlines as a repeatable token + bool penalize_nl = false; // consider newlines as a repeatable token + float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f + float dry_base = 1.75f; + int dry_allowed_length = 2; std::vector samplers_sequence = { llama_sampler_type::TOP_K, @@ -59,6 +62,7 @@ typedef struct llama_sampling_params { std::unordered_map logit_bias; // logit bias for specific tokens std::vector penalty_prompt_tokens; + std::vector dry_sequence_breakers; // sequence breakers for the DRY sampler bool use_penalty_prompt_tokens = false; } llama_sampling_params; diff --git a/llama.cpp b/llama.cpp index 7ebb91d72cd11..39c682a054382 100644 --- a/llama.cpp +++ b/llama.cpp @@ -12832,6 +12832,64 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can } } +void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_token_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) { + // loop through each candidate + for (size_t i = 0; i < candidates->size; ++i) { + + // if our candidate itself is part of the sequence breakers, we don't apply the dry penalty + if (std::find(seq_breakers, seq_breakers + seq_breakers_size, candidates->data[i].id) != seq_breakers + seq_breakers_size) { + continue; + } + + int max_match_length = 0; + + // loop through each previous token + for (size_t j = 0; j < last_token_size; ++j) { + // if the current candidate is the same as the previous token + if (candidates->data[i].id == last_tokens[j]) { + // greedily match sequence backwards starting from the current position with the end of prev + int match_length = 1; + + // loop through the previous tokens + for(;; match_length++) { + // if we have reached the start of our stored prev, break + if(j - match_length > 0) break; + + // this shouldn't happen because (j - match_length) should always be smaller than (size - match_length) + // but let's check here to avoid the unexpected + if(last_token_size - match_length < 0) break; + + // compare token starts at our prev index, going backwards by match length + auto compare_token = last_tokens[j - match_length]; + + // head token starts at the end of prev, going backwards by match length + auto head_token = last_tokens[last_token_size - match_length]; + + // if compare token is part of the sequence breakers, break out of the match + if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size) + break; + + // break out of the match if any tokens don't match + if(compare_token != head_token) + break; + } + + // update our max match length + max_match_length = std::max(max_match_length, match_length); + } + } + + // apply penalties + if(max_match_length > dry_allowed_length) { + // calculate the penalty + float penalty = dry_multiplier * pow(dry_base, max_match_length - dry_allowed_length); + + // apply the dry penalty + candidates->data[i].logit -= penalty; + } + } +} + void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { if (z >= 1.0f || candidates->size <= 2) { return; diff --git a/llama.h b/llama.h index 77c288eb25543..3ef8ce8e1b714 100644 --- a/llama.h +++ b/llama.h @@ -918,6 +918,18 @@ extern "C" { float p, size_t min_keep); + /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677 + LLAMA_API void llama_sample_dry( + struct llama_context * ctx, + llama_token_data_array * candidates, + const llama_token * last_tokens, + int last_token_size, + float dry_base, + float dry_multiplier, + int dry_allowed_length, + const llama_token * seq_breakers, + int seq_breakers_size); + /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. LLAMA_API void llama_sample_tail_free( struct llama_context * ctx, From c129369702655f0bafa06426fe8179c2c28d63ea Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 22 Apr 2024 21:42:43 +0300 Subject: [PATCH 155/180] cuda : try to fix __hgt2_mask ggml-ci --- ggml-cuda/common.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index e82d63e4a0abc..ca0d85ae9ab70 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -308,8 +308,8 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #if CUDART_VERSION < 12000 static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) { - const uint32_t mask_low = 0x0000FFFF * ( __low2half(a) > __low2half(b)); - const uint32_t mask_high = 0xFFFF0000 * (__high2half(a) > __high2half(b)); + const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b))); + const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b))); return mask_low | mask_high; } #endif // CUDART_VERSION < 12000 From 3864eea4cbf4d224d8ce798e7c537838048f540a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 23 Apr 2024 10:01:49 +0300 Subject: [PATCH 156/180] ggml : add TODO's for F16/F32 mask/pos support in other backends --- ggml-kompute.cpp | 7 +++++++ ggml-sycl.cpp | 6 +++++- ggml-vulkan.cpp | 5 +++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 407062e6fd476..9a469821d8042 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1427,6 +1427,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml for (int i = node_start; i < node_end; ++i) { struct ggml_tensor * src0 = gf->nodes[i]->src[0]; struct ggml_tensor * src1 = gf->nodes[i]->src[1]; + struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2); struct ggml_tensor * dst = gf->nodes[i]; GGML_ASSERT(dst->data != nullptr); @@ -1559,6 +1560,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml { float scale; memcpy(&scale, dst->op_params, sizeof(float)); + +#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support") +#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") + GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32); + GGML_ASSERT(src2 == nullptr); + ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale); } break; case GGML_OP_DIAG_MASK_INF: diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index a9b310243f04f..f8ed55eb82259 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -14738,7 +14738,12 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + const ggml_tensor * src2 = dst->src[2]; + +#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support") +#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -14754,7 +14759,6 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, float * src2_dd = nullptr; sycl_pool_alloc src2_f; - ggml_tensor * src2 = dst->src[2]; const bool use_src2 = src2 != nullptr; if (use_src2) { diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 1736ab7361c27..f712cdd5a900e 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -3178,6 +3178,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_SOFT_MAX: +#pragma message("TODO: add ggml_vk_soft_max() F16 src1 and src2 support") +#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); + if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_soft_max_f32; } From 78d363b0d4b776a0ead4246d6f3392df04a7044b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 23 Apr 2024 17:15:13 +0300 Subject: [PATCH 157/180] llama : replace bool need_kq_pos with use_alibi --- llama.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/llama.cpp b/llama.cpp index 26802d96a6752..83e0c2ef11831 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1823,7 +1823,7 @@ struct llama_hparams { float f_logit_scale = 0.0f; bool causal_attn = true; - bool need_kq_pos = false; // currently, we need KQ_pos data for ALiBi-based models + bool use_alibi = false; // currently, we need KQ_pos data for ALiBi-based models enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; @@ -4104,7 +4104,7 @@ static void llm_load_hparams( model.ftype = ml.ftype; if (hparams.f_max_alibi_bias > 0.0f) { - hparams.need_kq_pos = true; + hparams.use_alibi = true; } hparams.rope_type = llama_rope_type(&model); @@ -6269,7 +6269,6 @@ static struct ggml_tensor * llm_build_moe_ffn( return moe_out; } -// if max_alibi_bias > 0 then apply ALiBi static struct ggml_tensor * llm_build_kqv( struct ggml_context * ctx, const llama_model & model, @@ -6359,7 +6358,7 @@ static struct ggml_tensor * llm_build_kqv( #pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute") #pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024") #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488") - if (hparams.f_max_alibi_bias > 0.0f) { + if (hparams.use_alibi) { kq = ggml_scale(ctx, kq, kq_scale); cb(kq, "kq_scaled", il); @@ -10714,7 +10713,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (hparams.need_kq_pos) { + // ALiBi requires the KQ_pos tensor to provide the sequence position of each token in the batch + // this allows to process multiple sequences in parallel with ALiBi-based models + if (hparams.use_alibi) { const int64_t n_kv = kv_self.n; GGML_ASSERT(lctx.inp_KQ_pos); @@ -15116,7 +15117,7 @@ struct llama_context * llama_new_context_with_model( } } - if (cparams.flash_attn && hparams.need_kq_pos) { + if (cparams.flash_attn && hparams.use_alibi) { LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with ALiBi - forcing off\n", __func__); cparams.flash_attn = false; } From 19e8982f51c5bea11735f688092627f1461d8759 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 23 Apr 2024 17:24:28 +0300 Subject: [PATCH 158/180] llama : prep ALiBi support for BERT models ggml-ci --- ggml.c | 1 + llama.cpp | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index bc19f35bff84f..469a0e0d9afd6 100644 --- a/ggml.c +++ b/ggml.c @@ -5476,6 +5476,7 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_matrix(mask)); + GGML_ASSERT(mask->ne[0] == a->ne[0]); GGML_ASSERT(mask->ne[1] >= a->ne[1]); } diff --git a/llama.cpp b/llama.cpp index 83e0c2ef11831..4b38f5870adb0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6712,8 +6712,14 @@ struct llm_build_context { return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; } - struct ggml_tensor * build_inp_KQ_pos() { - lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); + struct ggml_tensor * build_inp_KQ_pos(bool causal = true) { + if (causal) { + lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); + } else { + // TODO: this will be needed for ALiBi-based BERT models + // https://github.com/ggerganov/llama.cpp/pull/6826 + lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_tokens); + } cb(lctx.inp_KQ_pos, "KQ_pos", -1); ggml_set_input(lctx.inp_KQ_pos); return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16) : lctx.inp_KQ_pos; From 56657e52e5f2b0fdcd414f99837b2e67efcf824c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 23 Apr 2024 17:30:37 +0300 Subject: [PATCH 159/180] llama : fix n_batch requirements ggml-ci --- llama.cpp | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/llama.cpp b/llama.cpp index 4b38f5870adb0..fcd15501e416f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15064,10 +15064,6 @@ struct llama_context * llama_new_context_with_model( const auto & hparams = model->hparams; auto & cparams = ctx->cparams; - // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask - // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext) - cparams.n_batch = std::max((uint32_t) GGML_KQ_MASK_PAD, params.n_batch); - cparams.n_seq_max = std::max(1u, params.n_seq_max); cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; @@ -15086,12 +15082,20 @@ struct llama_context * llama_new_context_with_model( cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; // this is necessary due to kv_self.n being padded later during inference - cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256); + cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256); // with causal attention, the batch size is limited by the context size cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; - cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask + // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext) + // ref: https://github.com/ggerganov/llama.cpp/pull/5021 + if (cparams.n_batch < GGML_KQ_MASK_PAD) { + LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD); + cparams.n_batch = std::max((uint32_t) GGML_KQ_MASK_PAD, params.n_batch); + } + + cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx : From d228bf8552f5a6afa0f4c523c0da4bc00312b791 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 23 Apr 2024 17:32:11 +0300 Subject: [PATCH 160/180] cont --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index fcd15501e416f..a3624544cf326 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15092,7 +15092,7 @@ struct llama_context * llama_new_context_with_model( // ref: https://github.com/ggerganov/llama.cpp/pull/5021 if (cparams.n_batch < GGML_KQ_MASK_PAD) { LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD); - cparams.n_batch = std::max((uint32_t) GGML_KQ_MASK_PAD, params.n_batch); + cparams.n_batch = GGML_KQ_MASK_PAD; } cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); From 751591d52074d6be53feed6e19211932e4520159 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 23 Apr 2024 18:16:25 +0300 Subject: [PATCH 161/180] server : add help for --flash-attn arg --- examples/server/server.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index f1754b60b7fe8..2cf59fbe0d66b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2357,6 +2357,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co printf(" --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled"); printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: enabled)\n"); + printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled"); printf(" -spf FNAME, --system-prompt-file FNAME\n"); printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n"); printf(" -ctk TYPE, --cache-type-k TYPE\n"); From d8b47dafe1f195e1a5835f87b545073b93828a61 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Wed, 24 Apr 2024 12:43:57 +0900 Subject: [PATCH 162/180] refactored sample_dry to be more efficient and closer to original implementation --- llama.cpp | 96 ++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 59 insertions(+), 37 deletions(-) diff --git a/llama.cpp b/llama.cpp index 32c67cb094141..efad3c12e74ea 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13045,58 +13045,80 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can } void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, size_t seq_breakers_size) { - // loop through each candidate - for (size_t i = 0; i < candidates->size; ++i) { + // sanity check + GGML_ASSERT(last_tokens_size > 0); + + // get the last token + auto last_token = last_tokens[last_tokens_size - 1]; + + // create an unordered map of "next tokens" <-> max match length + std::unordered_map match_lengths; - // if our candidate itself is part of the sequence breakers, we don't apply the dry penalty - if (std::find(seq_breakers, seq_breakers + seq_breakers_size, candidates->data[i].id) != seq_breakers + seq_breakers_size) { + // loop through each previous token (exclude the last token) + for (size_t i = 0; i < last_tokens_size - 1; ++i) { + // skip if the compare token if it's not the same as the last token + if(last_tokens[i] != last_token) { continue; } - int max_match_length = 0; + // get the next token (i + 1 is always less than last_tokens_size) + auto next_token = last_tokens[i + 1]; - // loop through each previous token - for (size_t j = 0; j < last_tokens_size; ++j) { - // if the current candidate is the same as the previous token - if (candidates->data[i].id == last_tokens[j]) { - // greedily match sequence backwards starting from the current position with the end of prev - int match_length = 1; + // try to extend the match backwards (match length starts a 1 because last token is already matched) + size_t match_length = 1; - // loop through the previous tokens - for(;; match_length++) { - // if we have reached the start of our stored prev, break - if(j - match_length > 0) break; + // loop through the previous tokens + for(;; match_length++) { + // if we have reached the start of our last tokens, break + if(i < match_length) break; - // (last_tokens_size - match_length) is unsigned so will always be greater or equal to 0 - // so no need to check for index out of bound here + // compare token starts at our prev index, going backwards by match length + auto compare_token = last_tokens[i - match_length]; - // compare token starts at our prev index, going backwards by match length - auto compare_token = last_tokens[j - match_length]; + // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself + auto head_token = last_tokens[last_tokens_size - 1 - match_length]; - // head token starts at the end of prev, going backwards by match length - auto head_token = last_tokens[last_tokens_size - match_length]; + // if compare token is part of the sequence breakers, break out of the match + if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size) + break; - // if compare token is part of the sequence breakers, break out of the match - if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size) - break; + // break out of the match if any tokens don't match + if(compare_token != head_token) + break; + } - // break out of the match if any tokens don't match - if(compare_token != head_token) - break; - } + // Check if the next token exists in the map + auto it = match_lengths.find(next_token); - // update our max match length - max_match_length = std::max(max_match_length, match_length); - } + if (it == match_lengths.end()) { + // Key does not exist, insert the new value + match_lengths[next_token] = match_length; + } else { + // Key exists, update it with the max of the new value or the existing value + it->second = std::max(it->second, match_length); } + } + + // apply penalties + for (const auto& pair : match_lengths) { + auto next_token = pair.first; + auto match_length = pair.second; - // apply penalties - if(max_match_length > dry_allowed_length) { - // calculate the penalty - float penalty = dry_multiplier * pow(dry_base, max_match_length - dry_allowed_length); + // if the match length is greater than our allowed length in config, we apply penalities + if(match_length > dry_allowed_length) { - // apply the dry penalty - candidates->data[i].logit -= penalty; + // find our next token in the candidates->data + size_t i = 0; + for (; i < candidates->size; ++i) { + if (candidates->data[i].id == next_token) { + // calculate the penalty + float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length); + + // apply the dry penalty + candidates->data[i].logit -= penalty; + break; + } + } } } } From a31ce1a182664b6ef83e3bd015e358dca4251be4 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Wed, 24 Apr 2024 13:40:09 +0900 Subject: [PATCH 163/180] added cmd to main for dry sampler --- common/common.cpp | 28 ++++++++++++++++++++++++++++ common/sampling.cpp | 4 ++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 06f252ea6914b..e3afa23f3e17a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -582,6 +582,30 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa sparams.penalty_present = std::stof(argv[i]); return true; } + if (arg == "--dry-multiplier") { + if (++i >= argc) { + invalid_param = true; + return true; + } + sparams.dry_multiplier = std::stof(argv[i]); + return true; + } + if (arg == "--dry-base") { + if (++i >= argc) { + invalid_param = true; + return true; + } + sparams.dry_base = std::stoi(argv[i]); + return true; + } + if (arg == "--dry-allowed-length") { + if (++i >= argc) { + invalid_param = true; + return true; + } + sparams.dry_allowed_length = std::stoi(argv[i]); + return true; + } if (arg == "--dynatemp-range") { if (++i >= argc) { invalid_param = true; @@ -1425,6 +1449,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat); printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present); printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq); + printf(" --dry-multiplier N DRY sampler multiplier (default: %.1f, 0.0 = disabled)\n", (double)sparams.dry_multiplier); + printf(" --dry-base N DRY sampler base (default: %.1f)\n", (double)sparams.dry_base); + printf(" --dry-allowed-length N\n"); + printf(" DRY sampler allowed length (default: %d)\n", sparams.dry_allowed_length); printf(" --dynatemp-range N dynamic temperature range (default: %.1f, 0.0 = disabled)\n", (double)sparams.dynatemp_range); printf(" --dynatemp-exp N dynamic temperature exponent (default: %.1f)\n", (double)sparams.dynatemp_exponent); printf(" --mirostat N use Mirostat sampling.\n"); diff --git a/common/sampling.cpp b/common/sampling.cpp index 7bb9101a7e065..8e107ea31bda7 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -99,10 +99,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) { snprintf(result, sizeof(result), "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n" - "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", + "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, dry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d", params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present, params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp, - params.mirostat, params.mirostat_eta, params.mirostat_tau); + params.mirostat, params.mirostat_eta, params.mirostat_tau, params.dry_multiplier, params.dry_base, params.dry_allowed_length); return std::string(result); } From 7e088852dd8a39fccda85b397325029f777c5ddd Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Wed, 24 Apr 2024 13:47:27 +0900 Subject: [PATCH 164/180] skip dry sampler if last token is part of sequence breakers --- llama.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/llama.cpp b/llama.cpp index efad3c12e74ea..a3042afa6f346 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13051,6 +13051,11 @@ void llama_sample_dry(llama_token_data_array * candidates, const llama_token * l // get the last token auto last_token = last_tokens[last_tokens_size - 1]; + // if last token is part of the sequence breakers, skip whole sampler + if(std::find(seq_breakers, seq_breakers + seq_breakers_size, last_token) != seq_breakers + seq_breakers_size) { + return; + } + // create an unordered map of "next tokens" <-> max match length std::unordered_map match_lengths; From 75c37ed8171b4d59dd15474049d9867fc76641cc Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Wed, 24 Apr 2024 13:51:41 +0900 Subject: [PATCH 165/180] fixed bug in dry sampler --- llama.cpp | 104 ++++++++++++++++++++++++++++++++++-------------------- llama.h | 2 +- 2 files changed, 66 insertions(+), 40 deletions(-) diff --git a/llama.cpp b/llama.cpp index 39c682a054382..7f67b2b5b8771 100644 --- a/llama.cpp +++ b/llama.cpp @@ -12832,60 +12832,86 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can } } -void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_token_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) { - // loop through each candidate - for (size_t i = 0; i < candidates->size; ++i) { +void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) { + // sanity check + GGML_ASSERT(last_tokens_size > 0); + + // get the last token + auto last_token = last_tokens[last_tokens_size - 1]; + + // if last token is part of the sequence breakers, skip whole sampler + if(std::find(seq_breakers, seq_breakers + seq_breakers_size, last_token) != seq_breakers + seq_breakers_size) { + return; + } - // if our candidate itself is part of the sequence breakers, we don't apply the dry penalty - if (std::find(seq_breakers, seq_breakers + seq_breakers_size, candidates->data[i].id) != seq_breakers + seq_breakers_size) { + // create an unordered map of "next tokens" <-> max match length + std::unordered_map match_lengths; + + // loop through each previous token (exclude the last token) + for (size_t i = 0; i < last_tokens_size - 1; ++i) { + // skip if the compare token if it's not the same as the last token + if(last_tokens[i] != last_token) { continue; } - int max_match_length = 0; + // get the next token (i + 1 is always less than last_tokens_size) + auto next_token = last_tokens[i + 1]; - // loop through each previous token - for (size_t j = 0; j < last_token_size; ++j) { - // if the current candidate is the same as the previous token - if (candidates->data[i].id == last_tokens[j]) { - // greedily match sequence backwards starting from the current position with the end of prev - int match_length = 1; + // try to extend the match backwards (match length starts a 1 because last token is already matched) + size_t match_length = 1; - // loop through the previous tokens - for(;; match_length++) { - // if we have reached the start of our stored prev, break - if(j - match_length > 0) break; + // loop through the previous tokens + for(;; match_length++) { + // if we have reached the start of our last tokens, break + if(i < match_length) break; - // this shouldn't happen because (j - match_length) should always be smaller than (size - match_length) - // but let's check here to avoid the unexpected - if(last_token_size - match_length < 0) break; + // compare token starts at our prev index, going backwards by match length + auto compare_token = last_tokens[i - match_length]; - // compare token starts at our prev index, going backwards by match length - auto compare_token = last_tokens[j - match_length]; + // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself + auto head_token = last_tokens[last_tokens_size - 1 - match_length]; - // head token starts at the end of prev, going backwards by match length - auto head_token = last_tokens[last_token_size - match_length]; + // if compare token is part of the sequence breakers, break out of the match + if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size) + break; - // if compare token is part of the sequence breakers, break out of the match - if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size) - break; + // break out of the match if any tokens don't match + if(compare_token != head_token) + break; + } - // break out of the match if any tokens don't match - if(compare_token != head_token) - break; - } + // Check if the next token exists in the map + auto it = match_lengths.find(next_token); - // update our max match length - max_match_length = std::max(max_match_length, match_length); - } + if (it == match_lengths.end()) { + // Key does not exist, insert the new value + match_lengths[next_token] = match_length; + } else { + // Key exists, update it with the max of the new value or the existing value + it->second = std::max(it->second, match_length); } + } - // apply penalties - if(max_match_length > dry_allowed_length) { - // calculate the penalty - float penalty = dry_multiplier * pow(dry_base, max_match_length - dry_allowed_length); + // apply penalties + for (const auto& pair : match_lengths) { + auto next_token = pair.first; + auto match_length = pair.second; - // apply the dry penalty - candidates->data[i].logit -= penalty; + // if the match length is greater than our allowed length in config, we apply penalities + if(match_length > dry_allowed_length) { + + // find our next token in the candidates->data + size_t i = 0; + for (; i < candidates->size; ++i) { + if (candidates->data[i].id == next_token) { + // calculate the penalty + float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length); + + // apply the dry penalty + candidates->data[i].logit -= penalty; + break; + } + } } } } diff --git a/llama.h b/llama.h index 3ef8ce8e1b714..847d9668183d6 100644 --- a/llama.h +++ b/llama.h @@ -923,7 +923,7 @@ extern "C" { struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, - int last_token_size, + int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, From ce281b904c0ed97f7f7c685a2c7bf8dbaa6f8293 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 24 Apr 2024 16:48:10 +0300 Subject: [PATCH 166/180] llama : disable FA for AMD --- ggml-cuda/common.cuh | 4 ++-- ggml-cuda/fattn.cu | 3 +++ llama.cpp | 7 +++++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index ca0d85ae9ab70..156eba6d1ef74 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -399,8 +399,8 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { #define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \ defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL -#define FP16_MMA_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \ - defined(RDNA3) : __CUDA_ARCH__ >= CC_VOLTA + +#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA // TODO: move to ggml-common.h static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index aaaea2f0701d8..df1e80068b334 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -2,7 +2,10 @@ #include "fattn.cuh" #include + +#if FP16_MMA_AVAILABLE #include +#endif #define FATTN_KQ_STRIDE 256 #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. diff --git a/llama.cpp b/llama.cpp index 11a1aa3a44c26..f00190a77bb79 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15357,6 +15357,13 @@ struct llama_context * llama_new_context_with_model( cparams.flash_attn = false; } +#ifdef GGML_USE_HIPBLAS + if (cparams.flash_attn) { + LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with HIPBLAS builds - forcing off\n", __func__); + cparams.flash_attn = false; + } +#endif + if (params.seed == LLAMA_DEFAULT_SEED) { params.seed = time(NULL); } From a11bf2dbcd6226888826c207909f982caf2bac90 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Thu, 25 Apr 2024 16:26:06 +0900 Subject: [PATCH 167/180] fixed small bug --- llama.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index 4b28bc630b6f4..e6ab4468560d4 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6391,6 +6391,9 @@ static struct ggml_tensor * llm_build_kqv( cb(k, "k", il); struct ggml_tensor * cur; + + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs @@ -6423,9 +6426,6 @@ static struct ggml_tensor * llm_build_kqv( cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); } else { - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); - if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 From ff2c64a9f4190cd5da3c59f65d8417299c72a336 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Apr 2024 15:51:46 +0300 Subject: [PATCH 168/180] tests : remove TMP_ATTN_BENCH ggml-ci --- tests/test-backend-ops.cpp | 70 -------------------------------------- 1 file changed, 70 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d044a6ea02481..b27c1291e4088 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -15,9 +15,6 @@ #include #include -// TODO: remove before merging -//#define TMP_ATTN_BENCH - static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { // static RNG initialization (revisit if n_threads stops being constant) static const size_t n_threads = std::thread::hardware_concurrency(); @@ -574,19 +571,9 @@ struct test_case { // duplicate the op size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1; -#ifndef TMP_ATTN_BENCH for (int i = 1; i < n_runs; i++) { gf->nodes[gf->n_nodes++] = out; } -#else - int n_nodes = gf->n_nodes; - n_runs = 1000; - for (int i = 1; i < n_runs; i++) { - for (int j = 0; j < n_nodes; j++) { - gf->nodes[gf->n_nodes++] = gf->nodes[j]; - } - } -#endif // calculate memory size_t mem = n_runs * op_size(out); @@ -1522,50 +1509,6 @@ struct test_flash_attn_ext : public test_case { } }; -#ifdef TMP_ATTN_BENCH -// ATTN -struct test_attn : public test_case { - const int64_t hs; // head size - const int64_t nh; // num heads - const int64_t kv; // kv size - const int64_t nb; // batch size - - std::string op_desc(ggml_tensor * t) override { - return "ATTN"; - - GGML_UNUSED(t); - } - - std::string vars() override { - return VARS_TO_STR4(hs, nh, kv, nb); - } - - double max_nmse_err() override { - return 5e-4; - } - - test_attn(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) - : hs(hs), nh(nh), kv(kv), nb(nb) {} - - ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); - ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); - ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed - ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, nb, 1, 1); - - struct ggml_tensor * cur; - - cur = ggml_mul_mat (ctx, k, q); - cur = ggml_soft_max_ext(ctx, cur, mask, nullptr, 1.0f/sqrtf(hs), 0.0f); - cur = ggml_mul_mat (ctx, v, cur); - cur = ggml_permute (ctx, cur, 0, 2, 1, 3); - cur = ggml_cont_2d (ctx, cur, hs*nh, nb); - - return cur; - } -}; -#endif - enum llm_norm_type { LLM_NORM, LLM_NORM_RMS, @@ -2230,18 +2173,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_timestep_embedding()); test_cases.emplace_back(new test_leaky_relu()); -#ifdef TMP_ATTN_BENCH - for (int hs : { 128, 256, 64, 80, }) { - for (int nh : { 32, }) { - for (int kv : { 512, 1024, 2048, 4096, }) { - for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { - test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); - test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); - } - } - } - } -#else for (int hs : { 64, 80, 128, 256, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, }) { @@ -2251,7 +2182,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } } -#endif // these tests are disabled to save execution time, but they can be handy for debugging #if 0 From 1fd5bc3d5e4ebfad3499d59dfee60202a4b7bb72 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Apr 2024 18:18:13 +0300 Subject: [PATCH 169/180] llama : support save/load state with FA enabled ggml-ci --- ci/run.sh | 3 ++- llama.cpp | 16 ++++++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/ci/run.sh b/ci/run.sh index da05f0d48802a..56beaedea0408 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -517,7 +517,8 @@ function gg_run_open_llama_7b_v2 { (time ./bin/imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log - (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state --model -fa ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { qnt="$1" diff --git a/llama.cpp b/llama.cpp index 718d5cccb6a9c..65ac6f6f26a8c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2036,8 +2036,8 @@ struct llama_kv_cache { bool has_shift = false; bool do_defrag = false; bool do_copy = false; - // with recurrent state models, a cell can hold the state for more than one past token - bool recurrent = false; + bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token + bool v_trans = true; // the value tensor is transposed // Note: The value of head isn't only used to optimize searching // for a free KV slot. llama_decode_internal also uses it, so it @@ -2335,11 +2335,14 @@ struct llama_context { static bool llama_kv_cache_init( struct llama_kv_cache & cache, - const llama_model & model, + const llama_context * ctx, ggml_type type_k, ggml_type type_v, uint32_t kv_size, bool offload) { + const llama_model & model = ctx->model; + const llama_cparams & cparams = ctx->cparams; + const struct llama_hparams & hparams = model.hparams; const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); @@ -2350,6 +2353,7 @@ static bool llama_kv_cache_init( // TODO: find a nicer way to add other recurrent model architectures cache.recurrent = model.arch == LLM_ARCH_MAMBA; + cache.v_trans = !cparams.flash_attn; // TODO: support mixed reccurent Transformer architectues // NOTE: (!a || b) is a logical implication (a -> b) @@ -15550,7 +15554,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) { + if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -16330,7 +16334,7 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size()); data_ctx->write(tmp_buf.data(), tmp_buf.size()); - if (kv_self.recurrent) { + if (kv_self.recurrent || !kv_self.v_trans) { // v is contiguous for recurrent models // TODO: use other tensors for state models than k and v const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); @@ -16486,7 +16490,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size); inp += k_size; - if (kv_self.recurrent) { + if (kv_self.recurrent || !kv_self.v_trans) { // v is contiguous for recurrent models // TODO: use other tensors for state models than k and v const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); From ac1c6d91de2d77a39f67ce55fd1ef6772d7e4a4a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Apr 2024 19:03:59 +0300 Subject: [PATCH 170/180] ci : add CUDA save-load-state tests ggml-ci --- ci/run.sh | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ci/run.sh b/ci/run.sh index 56beaedea0408..fda1169b2c2f2 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -336,7 +336,8 @@ function gg_run_open_llama_3b_v2 { (time ./bin/imatrix --model ${model_f16} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log - (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state -fa --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { qnt="$1" @@ -517,8 +518,10 @@ function gg_run_open_llama_7b_v2 { (time ./bin/imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log - (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/save-load-state --model -fa ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state --model -ngl 10 ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state --model -fa -ngl 10 ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state --model -ngl 99 ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state --model -fa -ngl 99 ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { qnt="$1" From c225609f1003612798698bea8b1a4d6e8d0c3da8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Apr 2024 19:37:27 +0300 Subject: [PATCH 171/180] llama : llama_kv_cache_clear zeroes data + fix save-load seq ggml-ci --- llama.cpp | 147 +++++++++++++++++++++++++++++++++++++----------------- llama.h | 2 +- 2 files changed, 103 insertions(+), 46 deletions(-) diff --git a/llama.cpp b/llama.cpp index eaf1d60b43bca..6db0b95710c0d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2566,6 +2566,10 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) { } cache.head = 0; cache.used = 0; + + for (auto & buf : cache.bufs) { + ggml_backend_buffer_clear(buf, 0); + } } static bool llama_kv_cache_seq_rm( @@ -16483,6 +16487,8 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { __func__, kv_head, kv_size, kv_self.size); } + llama_kv_cache_clear(ctx); + if (kv_buf_size) { const size_t pre_kv_buf_size = inp - src; @@ -16516,8 +16522,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size); } - llama_kv_cache_clear(ctx); - ctx->kv_self.head = kv_head; ctx->kv_self.used = kv_used; @@ -16777,28 +16781,48 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam } } - // For the values, they are transposed, so we also need the element size and get the element ranges from each row - const uint32_t kv_size = kv_self.size; - for (int il = 0; il < (int)n_layer; ++il) { - // Write value type - const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; - data_ctx.write(&v_type_i, sizeof(v_type_i)); + if (!kv_self.v_trans) { + for (int il = 0; il < (int)n_layer; ++il) { + // Write key type + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + data_ctx.write(&v_type_i, sizeof(v_type_i)); - // Write element size - const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); - data_ctx.write(&v_size_el, sizeof(v_size_el)); + // Write row size of key + const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + data_ctx.write(&v_size_row, sizeof(v_size_row)); - // For each row, we get the element values of each cell - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out + // Read each range of cells of v_size length each into tmp_buf and write out for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; - const size_t src_offset = (range.first + j * kv_size) * v_size_el; - tmp_buf.resize(range_size * v_size_el); - ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size()); + tmp_buf.resize(range_size * v_size_row); + ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row); data_ctx.write(tmp_buf.data(), tmp_buf.size()); } } + } else { + // For the values, they are transposed, so we also need the element size and get the element ranges from each row + const uint32_t kv_size = kv_self.size; + for (int il = 0; il < (int)n_layer; ++il) { + // Write value type + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + data_ctx.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + data_ctx.write(&v_size_el, sizeof(v_size_el)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + tmp_buf.resize(range_size * v_size_el); + ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size()); + data_ctx.write(tmp_buf.data(), tmp_buf.size()); + } + } + } } return data_ctx.get_size_written(); @@ -16923,41 +16947,74 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, } } - // For each layer, read the values for each cell (transposed) - for (int il = 0; il < (int)n_layer; ++il) { - // Read type of value - int32_t v_type_i_ref; - memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); - inp += sizeof(v_type_i_ref); - const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; - if (v_type_i != v_type_i_ref) { - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return 0; - } + if (!kv_self.v_trans) { + for (int il = 0; il < (int)n_layer; ++il) { + // Read type of key + int32_t v_type_i_ref; + memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); + inp += sizeof(v_type_i_ref); + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return 0; + } - // Read element size of value - size_t v_size_el_ref; - memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); - inp += sizeof(v_size_el_ref); - const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); - if (v_size_el != v_size_el_ref) { - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il); - return 0; - } + // Read row size of key + size_t v_size_row_ref; + memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref)); + inp += sizeof(v_size_row_ref); + const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il); + return 0; + } - if (cell_count) { - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (kv_head + j * kv_size) * v_size_el; - ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el); - inp += cell_count * v_size_el; + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row); + inp += cell_count * v_size_row; + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (int il = 0; il < (int)n_layer; ++il) { + // Read type of value + int32_t v_type_i_ref; + memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); + inp += sizeof(v_type_i_ref); + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return 0; + } + + // Read element size of value + size_t v_size_el_ref; + memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); + inp += sizeof(v_size_el_ref); + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + if (v_size_el != v_size_el_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il); + return 0; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (kv_head + j * kv_size) * v_size_el; + ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el); + inp += cell_count * v_size_el; + } } } } const size_t nread = inp - src; + return nread; } diff --git a/llama.h b/llama.h index 792ef74d364ca..bedbc7c2c6e30 100644 --- a/llama.h +++ b/llama.h @@ -526,7 +526,7 @@ extern "C" { // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx); - // Clear the KV cache + // Clear the KV cache - both cell info is erased and KV data is zeroed LLAMA_API void llama_kv_cache_clear( struct llama_context * ctx); From bab346ba69de6117bc37a8604ffe95ecaed84664 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Apr 2024 19:45:36 +0300 Subject: [PATCH 172/180] llama : fix copy-paste errors, add TODO --- llama.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/llama.cpp b/llama.cpp index 6db0b95710c0d..ecdc4b7fcf4fc 100644 --- a/llama.cpp +++ b/llama.cpp @@ -16781,13 +16781,14 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam } } + // TODO: simplify, reduce copy-paste if (!kv_self.v_trans) { for (int il = 0; il < (int)n_layer; ++il) { - // Write key type + // Write value type const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; data_ctx.write(&v_type_i, sizeof(v_type_i)); - // Write row size of key + // Write row size of value const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); data_ctx.write(&v_size_row, sizeof(v_size_row)); @@ -16947,32 +16948,33 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, } } + // TODO: simplify, reduce copy-paste if (!kv_self.v_trans) { for (int il = 0; il < (int)n_layer; ++il) { - // Read type of key + // Read type of value int32_t v_type_i_ref; memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); inp += sizeof(v_type_i_ref); const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; if (v_type_i != v_type_i_ref) { llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); return 0; } - // Read row size of key + // Read row size of value size_t v_size_row_ref; memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref)); inp += sizeof(v_size_row_ref); const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); if (v_size_row != v_size_row_ref) { llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il); + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il); return 0; } if (cell_count) { - // Read and set the keys for the whole cell range + // Read and set the values for the whole cell range ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row); inp += cell_count * v_size_row; } From 0fc5c5eb74accef6a5904e4933e2d3b08e3cd34a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Apr 2024 19:53:57 +0300 Subject: [PATCH 173/180] llama : disallow incompatible states --- llama.cpp | 6 ++++++ llama.h | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index ecdc4b7fcf4fc..8b258f988e514 100644 --- a/llama.cpp +++ b/llama.cpp @@ -16323,11 +16323,13 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data const uint32_t kv_size = kv_self.size; const size_t kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head; const uint32_t kv_used = kv_self.used; + const uint32_t v_trans = kv_self.v_trans ? 1 : 0; data_ctx->write(&kv_buf_size, sizeof(kv_buf_size)); data_ctx->write(&kv_head, sizeof(kv_head)); data_ctx->write(&kv_size, sizeof(kv_size)); data_ctx->write(&kv_used, sizeof(kv_used)); + data_ctx->write(&v_trans, sizeof(v_trans)); if (kv_buf_size) { const size_t pre_kv_buf_size = data_ctx->get_size_written(); @@ -16473,11 +16475,15 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { uint32_t kv_head; uint32_t kv_size; uint32_t kv_used; + uint32_t v_trans; memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size); memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head); memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size); memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used); + memcpy(&v_trans, inp, sizeof(v_trans)); inp += sizeof(v_trans); + + GGML_ASSERT(kv_self.v_trans == (bool) v_trans); // incompatible V transposition if (kv_self.size != kv_size) { // the KV cache needs to be big enough to load all the KV cells from the saved state diff --git a/llama.h b/llama.h index bedbc7c2c6e30..9c7cdf99faa6d 100644 --- a/llama.h +++ b/llama.h @@ -40,7 +40,7 @@ #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 5 +#define LLAMA_SESSION_VERSION 6 #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_VERSION 1 From 1e590ac3c97534ba0ff34388a30d2430a7684c10 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Apr 2024 20:06:23 +0300 Subject: [PATCH 174/180] llama : update llama_state_get_size after v_trans field --- llama.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llama.cpp b/llama.cpp index 8b258f988e514..f41fba6c7f836 100644 --- a/llama.cpp +++ b/llama.cpp @@ -16157,6 +16157,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) { const size_t s_kv_head = sizeof(uint32_t); const size_t s_kv_size = sizeof(uint32_t); const size_t s_kv_used = sizeof(uint32_t); + const size_t s_v_trans = sizeof(uint32_t); const size_t s_kv = ctx->kv_self.total_size(); const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id); const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell; @@ -16174,6 +16175,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) { + s_kv_head + s_kv_size + s_kv_used + + s_v_trans + s_kv + s_kv_cells ); From 4f4c0249bf31e3b9c161670231d32e60293a3314 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Apr 2024 20:29:25 +0300 Subject: [PATCH 175/180] metal : remove tmp log --- ggml-metal.m | 3 --- 1 file changed, 3 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 1903791f1f9e3..249b8312cdf63 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -463,9 +463,6 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ id metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \ kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \ [metal_function release]; \ - GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ - (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ - (int) kernel->pipeline.threadExecutionWidth); \ if (error) { \ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ [metal_library release]; \ From 9e3876061c0734b6ec55326ea9d73f379594d9bd Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Apr 2024 20:33:36 +0300 Subject: [PATCH 176/180] llama : add static reminder for llama_state_get_size --- llama.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llama.cpp b/llama.cpp index f41fba6c7f836..3598921668767 100644 --- a/llama.cpp +++ b/llama.cpp @@ -16180,6 +16180,9 @@ size_t llama_state_get_size(const struct llama_context * ctx) { + s_kv_cells ); + // on session change it is very likely that the state size has changed - so we need to update this function + static_assert(LLAMA_SESSION_VERSION == 6, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?"); + return s_total; } From 1b24b38bd5ae03a37b8989a1d708a86b956af8c3 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Mon, 29 Apr 2024 10:26:17 +0900 Subject: [PATCH 177/180] synced dry implementation with llama.cpp PR --- common/sampling.cpp | 57 +++++++++++++++++++++++++-------------------- common/sampling.h | 7 +++--- llama.cpp | 40 +++++++++++++++++-------------- llama.h | 4 ++-- 4 files changed, 61 insertions(+), 47 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 289c0215bf541..aabeaa1d787de 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -279,9 +279,10 @@ static llama_token_data_array llama_sampling_prepare_impl( const bool penalize_nl = params.penalize_nl; // DRY sampler parameters - const float dry_multiplier = params.dry_multiplier; - const float dry_base = params.dry_base; - const int dry_allowed_length = params.dry_allowed_length; + const float dry_multiplier = params.dry_multiplier; + const float dry_base = params.dry_base; + const uint32_t dry_allowed_length = params.dry_allowed_length; + const uint32_t dry_penalty_last_n = params.dry_penalty_last_n; auto & prev = ctx_sampling->prev; auto & cur = ctx_sampling->cur; @@ -312,35 +313,41 @@ static llama_token_data_array llama_sampling_prepare_impl( llama_token_data_array cur_p = { cur.data(), cur.size(), false }; - // apply penalties const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev; - const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n); - if (penalty_tokens_used_size) { - const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; - - // repetition penalties - llama_sample_repetition_penalties(ctx_main, &cur_p, - penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, - penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); - - // DRY penalties (multiplier > 0 means enabled) - if(dry_multiplier > 0.0f) { - llama_sample_dry(&cur_p, - penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, - penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length, - params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size()); - } - if (!penalize_nl) { - for (size_t idx = 0; idx < cur_p.size; idx++) { - if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) { - cur_p.data[idx].logit = nl_logit; - break; + // apply repetition penalties + { + const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n); + if (penalty_tokens_used_size) { + const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; + + // repetition penalties + llama_sample_repetition_penalties(ctx_main, &cur_p, + penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, + penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); + + if (!penalize_nl) { + for (size_t idx = 0; idx < cur_p.size; idx++) { + if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) { + cur_p.data[idx].logit = nl_logit; + break; + } } } } } + // apply DRY penalties + { + const int penalty_tokens_used_size = std::min(penalty_tokens.size(), (size_t)dry_penalty_last_n); + if (penalty_tokens_used_size) { + llama_sample_dry(&cur_p, + penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, + penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length, + params.dry_seq_breakers.data(), params.dry_seq_breakers.size()); + } + } + // apply grammar checks before sampling logic if (apply_grammar && ctx_sampling->grammar != NULL) { llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar); diff --git a/common/sampling.h b/common/sampling.h index 7de531cdf71e1..8d00178cee5ac 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -41,9 +41,10 @@ typedef struct llama_sampling_params { float mirostat_eta = 0.10f; // learning rate bool penalize_nl = false; // consider newlines as a repeatable token uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context - float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f + float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f float dry_base = 1.75f; - int dry_allowed_length = 2; + uint32_t dry_allowed_length = 2; + uint32_t dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size) std::vector samplers_sequence = { llama_sampler_type::TOP_K, @@ -64,7 +65,7 @@ typedef struct llama_sampling_params { std::unordered_map logit_bias; // logit bias for specific tokens std::vector penalty_prompt_tokens; - std::vector dry_sequence_breakers; // sequence breakers for the DRY sampler + std::vector dry_seq_breakers; // sequence breakers for the DRY sampler bool use_penalty_prompt_tokens = false; } llama_sampling_params; diff --git a/llama.cpp b/llama.cpp index e6ab4468560d4..e65fd87c7870c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13307,15 +13307,15 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can } } -void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, size_t seq_breakers_size) { - // sanity check - GGML_ASSERT(last_tokens_size > 0); +void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, int dry_seq_breakers_size) { + // skip dry sampler if we don't have a previous token + if (last_tokens_size < 1) return; // get the last token auto last_token = last_tokens[last_tokens_size - 1]; // if last token is part of the sequence breakers, skip whole sampler - if(std::find(seq_breakers, seq_breakers + seq_breakers_size, last_token) != seq_breakers + seq_breakers_size) { + if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, last_token) != dry_seq_breakers + dry_seq_breakers_size) { return; } @@ -13324,21 +13324,26 @@ void llama_sample_dry(llama_token_data_array * candidates, const llama_token * l // loop through each previous token (exclude the last token) for (size_t i = 0; i < last_tokens_size - 1; ++i) { - // skip if the compare token if it's not the same as the last token - if(last_tokens[i] != last_token) { + // skip if the compare token is not the same as the last token + if (last_tokens[i] != last_token) { continue; } // get the next token (i + 1 is always less than last_tokens_size) auto next_token = last_tokens[i + 1]; - // try to extend the match backwards (match length starts a 1 because last token is already matched) + // if next token is part of the sequence breakers, skip + if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, next_token) != dry_seq_breakers + dry_seq_breakers_size) { + continue; + } + + // try to extend the match backwards (match length starts at 1 because last token is already matched) size_t match_length = 1; // loop through the previous tokens - for(;; match_length++) { + for (;; match_length++) { // if we have reached the start of our last tokens, break - if(i < match_length) break; + if (i < match_length) break; // compare token starts at our prev index, going backwards by match length auto compare_token = last_tokens[i - match_length]; @@ -13346,13 +13351,15 @@ void llama_sample_dry(llama_token_data_array * candidates, const llama_token * l // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself auto head_token = last_tokens[last_tokens_size - 1 - match_length]; - // if compare token is part of the sequence breakers, break out of the match - if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size) + // break out of the match if any tokens don't match + if (compare_token != head_token) { break; + } - // break out of the match if any tokens don't match - if(compare_token != head_token) + // if compare token is part of the sequence breakers, break out of the match + if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, compare_token) != dry_seq_breakers + dry_seq_breakers_size) { break; + } } // Check if the next token exists in the map @@ -13372,12 +13379,11 @@ void llama_sample_dry(llama_token_data_array * candidates, const llama_token * l auto next_token = pair.first; auto match_length = pair.second; - // if the match length is greater than our allowed length in config, we apply penalities - if(match_length > dry_allowed_length) { + // if the match length is greater than or equal to our allowed length in config, we apply penalities + if (match_length >= dry_allowed_length) { // find our next token in the candidates->data - size_t i = 0; - for (; i < candidates->size; ++i) { + for (size_t i = 0; i < candidates->size; ++i) { if (candidates->data[i].id == next_token) { // calculate the penalty float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length); diff --git a/llama.h b/llama.h index 3e4a68465c091..bae887fa4e773 100644 --- a/llama.h +++ b/llama.h @@ -933,8 +933,8 @@ extern "C" { float dry_base, float dry_multiplier, int dry_allowed_length, - const llama_token * seq_breakers, - size_t seq_breakers_size); + const llama_token * dry_seq_breakers, + size_t dry_seq_breakers_size); /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. LLAMA_API void llama_sample_tail_free( From 908f057a3f1055b26de8603c06aa9cc78c9578ac Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Mon, 29 Apr 2024 11:22:43 +0900 Subject: [PATCH 178/180] fixed linking error in llama.cpp --- llama.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index 59a654fb34386..563f85d6c8d2f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6449,7 +6449,7 @@ static struct ggml_tensor * llm_build_kqv( cb(k, "k", il); struct ggml_tensor * cur; - + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); cb(kq, "kq", il); @@ -13369,7 +13369,7 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can } } -void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, int dry_seq_breakers_size) { +void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) { // skip dry sampler if we don't have a previous token if (last_tokens_size < 1) return; @@ -18001,4 +18001,4 @@ static void llama_log_callback_default(ggml_log_level level, const char * text, (void) user_data; fputs(text, stderr); fflush(stderr); -} \ No newline at end of file +} From e180fcd3d53cee6ed5ac82473b45c468379a0687 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Apr 2024 11:04:32 +0300 Subject: [PATCH 179/180] metal : fix max nsg ggml-ci --- ggml-metal.m | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/ggml-metal.m b/ggml-metal.m index 249b8312cdf63..c6d580b8462d0 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2643,13 +2643,25 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(nqptg % 8 == 0); GGML_ASSERT(ncpsg % 32 == 0); + int64_t nsgmax = 2; + + while (true) { + const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); + if (smem > ctx->device.maxThreadgroupMemoryLength) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; From b7e0d148e00abd626e09927417c1ec762d2ee600 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 30 Apr 2024 12:13:50 +0000 Subject: [PATCH 180/180] remove printing of token IDs in main --- examples/main/main.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 37205c792cfd8..0f4e4fcbed4e1 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -126,6 +126,8 @@ int main(int argc, char ** argv) { } llama_sampling_params & sparams = params.sparams; + sparams.dry_multiplier = 0.8f; + #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("main", "log")); LOG_TEE("Log start\n"); @@ -735,7 +737,7 @@ int main(int argc, char ** argv) { for (auto id : embd) { const std::string token_str = llama_token_to_piece(ctx, id); printf("%s", token_str.c_str()); - printf("(%d)", id); + //printf("(%d)", id); if (embd.size() > 1) { input_tokens.push_back(id);