diff --git a/CMakeLists.txt b/CMakeLists.txt index ffda74a700bef..aa8f33b40f325 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -313,6 +313,37 @@ if (LLAMA_CLBLAST) endif() endif() +if (LLAMA_VULKAN) + find_package(Vulkan COMPONENTS glslc) + if (Vulkan_FOUND) + message(STATUS "Vulkan found") + + add_library(ggml-vulkan STATIC ggml-vulkan.cpp ggml-vulkan.h) + target_link_libraries(ggml-vulkan PUBLIC Vulkan::Vulkan) + + set(GGML_VULKAN_SHADERS matmul_f32 matmul_f16 f16_to_f32 dequant_q4_0) + + foreach(s IN LISTS GGML_VULKAN_SHADERS) + add_custom_command( + OUTPUT "vk_shaders/${s}.spv" + COMMAND "${Vulkan_GLSLC_EXECUTABLE}" + -fshader-stage=compute + --target-env=vulkan1.2 + "${CMAKE_CURRENT_SOURCE_DIR}/vk_shaders/${s}.glsl" + -o "${CMAKE_CURRENT_BINARY_DIR}/vk_shaders/${s}.spv" + DEPENDS "vk_shaders/${s}.glsl" + ) + target_sources(ggml-vulkan PRIVATE "vk_shaders/${s}.spv") + endforeach() + + add_compile_definitions(GGML_USE_VULKAN) + + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ggml-vulkan) + else() + message(WARNING "Vulkan not found") + endif() +endif() + if (LLAMA_ALL_WARNINGS) if (NOT MSVC) set(c_flags