Skip to content

Commit

Permalink
updated vulkan includes and queue discovery
Browse files: browse the repository at this point in the history
  • Loading branch information
l3utterfly committed Feb 21, 2024
1 parent 964ba95 commit 5acc721
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ if (LLAMA_VULKAN)
endif()

set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} Vulkan::Vulkan)
set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${VULKAN_PATH}/include)
else()
message(WARNING "Vulkan not found")
endif()
Expand Down
4 changes: 2 additions & 2 deletions ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1240,7 +1240,7 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);

// we need to use compute queue as the transfer queue
const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer | vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer | vk::QueueFlagBits::eGraphics, vk::QueueFlagBits::eCompute, compute_queue_family_index, 1);
//const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);

const float priorities[] = { 1.0f, 1.0f };
Expand Down Expand Up @@ -5641,7 +5641,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_SOFT_MAX) {
if (src1 != nullptr) {
tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, *(float *)tensor->op_params);
tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, NULL, *(float *)tensor->op_params, 0.0f);
} else {
tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
}
Expand Down

0 comments on commit 5acc721

Please sign in to comment.