diff --git a/neural_speed/core/CMakeLists.txt b/neural_speed/core/CMakeLists.txt index 8008bb6f5..427bb8aa0 100644 --- a/neural_speed/core/CMakeLists.txt +++ b/neural_speed/core/CMakeLists.txt @@ -14,6 +14,8 @@ find_package(Threads REQUIRED) file(GLOB layers_srcs "layers/*.cpp") +file(GLOB test_srcs "layers/*test*.cpp") +list(REMOVE_ITEM layers_srcs ${test_srcs}) set(sources ne_layers.c ${layers_srcs}) add_shareable_library_w_warning(ne_layers "${sources}") @@ -37,7 +39,7 @@ endif() if (NS_BUILD_TESTS) -function(add_test_target src) +function(add_test_target src) # ARGN: additional source get_filename_component(test_target ${src} NAME_WE) get_filename_component(src_dir ${src} DIRECTORY) string(REGEX REPLACE [/\\] "_" src_dir ${src_dir}) @@ -45,7 +47,7 @@ function(add_test_target src) set (test_target "${src_dir}_${test_target}") endif() set (test_target "test_${test_target}") - add_executable_w_warning(${test_target} ${src}) + add_executable_w_warning(${test_target} ${src} ${ARGN}) target_compile_definitions(${test_target} PRIVATE NS_TESTS) target_compile_options(${test_target} PRIVATE -fsanitize=address) target_link_options(${test_target} PRIVATE -fsanitize=address) @@ -58,6 +60,6 @@ function(add_test_target src) set_tests_properties(${test_target} PROPERTIES LABELS "${src_dir}_test") endfunction() -add_test_target(layers/mha_dense.cpp) +add_test_target(layers/mha_dense.cpp layers/mha_dense_tests.cpp) endif() diff --git a/neural_speed/core/layers/mha_dense.cpp b/neural_speed/core/layers/mha_dense.cpp index 2247df104..ef6c216ae 100644 --- a/neural_speed/core/layers/mha_dense.cpp +++ b/neural_speed/core/layers/mha_dense.cpp @@ -21,12 +21,6 @@ #include #include -#ifdef NS_TESTS -#include -#include - -#include "layers/ne_test_layers_utils.hpp" -#endif #include "core/data_types.h" #include "mha_dense_wrapper.h" @@ -502,482 +496,3 @@ void bestla_fusion_attn_fp32_batch_cpy_v(const bestla_fusion_attn_fp32_batch_cpy return params->no_zeroing ? bestla_fusion_attn_fp32_batch_cpy_v_(params) : bestla_fusion_attn_fp32_batch_cpy_v_(params); } - -// #ifdef __GNUC__ -// #pragma GCC pop_options -// #endif - -#ifdef NS_TESTS -#define CheckISA(ISA) \ - (bestla::device::CpuDevice::getInstance()->ISA() || (printf("Wrong Device ISA: " #ISA "\n"), false)) - -namespace { -bool ret_ok = true; - -class TestMhaDese { - public: - TestMhaDese() { - printf("Test suit: %s\n", __FUNCTION__); - GetCPUDevice(); - ne_threading::get()->set_threads(std::min(_cd->getThreads(), omp_get_max_threads())); - -#if CompileFP16() - if (CheckISA(AMX_BF16)) { - ret_ok &= test_case({1, 1, 1, 32, 128, 64}, NE_ATTN_FLAG_NONE); - ret_ok &= test_case({2, 5, 5, 32, 64, 128}, NE_ATTN_FLAG_NONE); - ret_ok &= test_case({2, 5, 5, 80, 128, 77}, NE_ATTN_FLAG_NONE); - ret_ok &= test_case({1, 1, 1, 32, 63, 63}, NE_ATTN_FLAG_NONE); - ret_ok &= test_case({3, 4, 4, 256, 1, 384}, NE_ATTN_FLAG_NONE); - ret_ok &= test_case({1, 1, 1, 64, 64, 64}, NE_ATTN_FLAG_IS_CAUSAL); - - ret_ok &= test_case({1, 1, 1, 32, 128, 64}, NE_ATTN_FLAG_NONE, true); - ret_ok &= test_case({2, 5, 5, 32, 64, 128}, NE_ATTN_FLAG_NONE, true); - ret_ok &= test_case({2, 5, 5, 80, 128, 77}, NE_ATTN_FLAG_NONE, true); - ret_ok &= test_case({1, 1, 1, 256, 63, 63}, NE_ATTN_FLAG_NONE, true); - ret_ok &= test_case({3, 4, 4, 256, 1, 384}, NE_ATTN_FLAG_NONE, true); - ret_ok &= test_case({1, 1, 1, 64, 64, 64}, NE_ATTN_FLAG_IS_CAUSAL, true); - - ret_ok &= test_case({1, 1, 1, 32, 128, 64}, NE_ATTN_FLAG_NONE, true); - ret_ok &= test_case({2, 5, 5, 32, 64, 128}, NE_ATTN_FLAG_NONE, true); - ret_ok &= test_case({2, 5, 5, 80, 128, 77}, NE_ATTN_FLAG_NONE, true); - ret_ok &= test_case({1, 1, 1, 256, 63, 63}, NE_ATTN_FLAG_NONE, true); - ret_ok &= test_case({3, 4, 4, 256, 1, 384}, NE_ATTN_FLAG_NONE, true); - ret_ok &= test_case({1, 1, 1, 64, 64, 64}, NE_ATTN_FLAG_IS_CAUSAL, true); - } -#endif - - if (CheckISA(AMX_BF16)) { - const auto BA48b4a = ATTN_FWD_LAYOUT_NTILE48_ROWPACK4; - ret_ok &= test_case({1, 1, 1, 32, 128, 64}, NE_ATTN_FLAG_NONE, false, BA48b4a); - ret_ok &= test_case({2, 5, 5, 32, 64, 128}, NE_ATTN_FLAG_NONE, false, BA48b4a); - ret_ok &= test_case({2, 5, 5, 80, 128, 77}, NE_ATTN_FLAG_NONE, false, BA48b4a); - ret_ok &= test_case({1, 1, 1, 256, 63, 63}, NE_ATTN_FLAG_NONE, false, BA48b4a); - ret_ok &= test_case({3, 4, 4, 256, 1, 384}, NE_ATTN_FLAG_NONE, false, BA48b4a); - ret_ok &= - test_case({1, 1, 1, 64, 64, 64}, NE_ATTN_FLAG_IS_CAUSAL, false, BA48b4a); - } - - if (CheckISA(AMX_BF16)) { - const auto BA48b2a = ATTN_FWD_LAYOUT_NTILE48_ROWPACK2; - int flags = NE_ATTN_FLAG_NONE; - ret_ok &= test_case({1, 1, 1, 32, 128, 64}, flags, false, BA48b2a, 1e-3f); - ret_ok &= test_case({2, 5, 5, 32, 64, 128}, flags, false, BA48b2a, 1e-3f); - ret_ok &= test_case({2, 5, 5, 80, 128, 77}, flags, false, BA48b2a, 1e-3f); - ret_ok &= test_case({1, 1, 1, 256, 63, 63}, flags, false, BA48b2a, 1e-3f); - ret_ok &= test_case({3, 4, 4, 256, 1, 384}, flags, false, BA48b2a, 1e-3f); - - flags |= NE_ATTN_FLAG_IS_CAUSAL; - ret_ok &= test_case({1, 1, 1, 64, 64, 64}, flags, false, BA48b2a, 1e-3f); - } - - if (CheckISA(AVX512F)) { // PREFER_FP32 - const auto BA48b2a = ATTN_FWD_LAYOUT_NTILE48_ROWPACK2; - int flags = NE_ATTN_FLAG_PREFER_FP32; - ret_ok &= test_case({1, 1, 1, 32, 128, 64}, flags, false, BA48b2a, 1e-3f); - ret_ok &= test_case({2, 5, 5, 32, 64, 128}, flags, false, BA48b2a, 1e-3f); - ret_ok &= test_case({2, 5, 5, 80, 128, 77}, flags, false, BA48b2a, 1e-3f); - ret_ok &= test_case({1, 1, 1, 256, 63, 63}, flags, false, BA48b2a, 1e-3f); - ret_ok &= test_case({3, 4, 4, 256, 1, 384}, flags, false, BA48b2a, 1e-3f); - - flags |= NE_ATTN_FLAG_IS_CAUSAL; - ret_ok &= test_case({1, 1, 1, 64, 64, 64}, flags, false, BA48b2a, 1e-3f); - } - if (CheckISA(AVX2)) { // avx2 - const auto Ba24b = ATTN_FWD_LAYOUT_NTILE24_ROWPACK1; - int flags = NE_ATTN_FLAG_PREFER_FP32; - ret_ok &= test_case({1, 1, 1, 32, 128, 64}, flags, false, Ba24b, 1e-3f); - ret_ok &= test_case({2, 5, 5, 32, 64, 128}, flags, false, Ba24b, 1e-3f); - ret_ok &= test_case({2, 5, 5, 80, 128, 77}, flags, false, Ba24b, 1e-3f); - ret_ok &= test_case({1, 1, 1, 256, 63, 63}, flags, false, Ba24b, 1e-3f); - ret_ok &= test_case({3, 4, 4, 256, 1, 384}, flags, false, Ba24b, 1e-3f); - - flags |= NE_ATTN_FLAG_IS_CAUSAL; - ret_ok &= test_case({1, 1, 1, 64, 64, 64}, flags, false, Ba24b, 1e-3f); - } - - { // amxbf16 => avx2 fallback - int flags = NE_ATTN_FLAG_NONE; - ret_ok &= test_reorder_pipe({1, 1, 1, 32, 128, 64}, 64, flags); - ret_ok &= test_reorder_pipe({2, 5, 5, 32, 64, 128}, 256, flags); - ret_ok &= test_reorder_pipe({2, 5, 5, 80, 128, 77}, 256, flags); - ret_ok &= test_reorder_pipe({2, 5, 1, 80, 128, 77}, 256, flags); - ret_ok &= test_reorder_pipe({1, 1, 1, 256, 63, 63}, 256, flags); - ret_ok &= test_reorder_pipe({3, 4, 4, 256, 1, 384}, 384, flags); - ret_ok &= test_reorder_pipe({3, 4, 2, 256, 1, 384}, 384, flags); - flags |= NE_ATTN_FLAG_IS_CAUSAL; - ret_ok &= test_reorder_pipe({1, 1, 1, 64, 64, 64}, 128, flags); - flags |= NE_ATTN_FLAG_IS_ALIBI8; - ret_ok &= test_reorder_pipe({1, 8, 8, 64, 64, 64}, 128, flags); - } - printf("Test suit done: %s\n", __FUNCTION__); - } - - template - static constexpr float init_min_val = std::is_same::value ? -127.f - : std::is_same::value ? 0.f - : -1.f; - template - static constexpr float init_max_val = std::is_same::value ? 127.f - : std::is_same::value ? 255.f - : 1.f; - template - static constexpr float init_scale_val = 1.f / init_max_val; - -#ifdef _MSC_VER -#define __PRETTY_FUNCTION__ __FUNCSIG__ -#endif - - template - bool test_case(const attn_shape_t& s, ne_attn_flags_t flags, bool k_trans = false, - ATTN_FWD_LAYOUT kv_layout = ATTN_FWD_LAYOUT_PLAIN, float eps = 1e-2f) { - assert(kv_layout == ATTN_FWD_LAYOUT_PLAIN || !k_trans); - const auto batch_size = s.batch_size; - const auto head_num = s.head_num; - const auto heads_kv = s.heads_kv; - const auto head_size = s.head_size; - const auto sl_q = s.sl_q; - const auto sl_kv = s.sl_kv; - assert(("GQA not supported!", s.head_num == s.heads_kv)); - - const auto is_causal = flags & NE_ATTN_FLAG_IS_CAUSAL ? "maksed" : "unmask"; - const auto is_alibi8 = flags & NE_ATTN_FLAG_IS_ALIBI8 ? "alibi8" : ""; - const auto prefer_fp32 = flags & NE_ATTN_FLAG_PREFER_FP32 ? "FP32" : ""; - printf("\ntest_case: %s\t", __PRETTY_FUNCTION__); - printf("bs_%d hn_%d hkv_%d hs_%d sl_q_%d sk_kv_%d %s %s %s\n", batch_size, head_num, heads_kv, head_size, sl_q, - sl_kv, is_causal, is_alibi8, prefer_fp32); - - const auto NTILE = kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK4 ? 48 - : kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2 ? 48 - : kv_layout == ATTN_FWD_LAYOUT_NTILE24_ROWPACK1 ? 24 - : 0; - const auto ROWPACK = kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK4 ? 4 - : kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2 ? 2 - : kv_layout == ATTN_FWD_LAYOUT_NTILE24_ROWPACK1 ? 1 - : 0; - const auto ROWPAD = ROWPACK > 1 ? ROWPACK * 16 : 1; - const auto k_rows_pad = kv_layout != ATTN_FWD_LAYOUT_PLAIN ? padto(head_size, ROWPAD) : head_size; - const auto k_cols_pad = kv_layout != ATTN_FWD_LAYOUT_PLAIN ? padto(sl_kv, NTILE) : sl_kv; - const auto v_rows_pad = kv_layout != ATTN_FWD_LAYOUT_PLAIN ? padto(sl_kv, ROWPAD) : sl_kv; - const auto v_cols_pad = kv_layout != ATTN_FWD_LAYOUT_PLAIN ? padto(head_size, NTILE) : head_size; - - std::vector src_q(batch_size * head_num * sl_q * head_size); - std::vector src_k(batch_size * heads_kv * k_rows_pad * k_cols_pad); - std::vector src_v(batch_size * heads_kv * v_rows_pad * v_cols_pad); - std::vector dst(batch_size * head_num * sl_q * head_size); - std::vector ref(batch_size * head_num * sl_q * head_size); // reference result - std::vector tmp(bestla_fusion_attn_workspace_size(&s)); - - // init vector - static std::mt19937 rng(1); - std::uniform_int_distribution<> dist; - init_vector(&src_q, init_min_val, init_max_val, dist(rng)); - init_vector(&src_k, init_min_val, init_max_val, dist(rng)); - init_vector(&src_v, init_min_val, init_max_val, dist(rng)); - - // pad0 for padded layouts - if (kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK4 || kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2 || - kv_layout == ATTN_FWD_LAYOUT_NTILE24_ROWPACK1) { -#pragma omp parallel for collapse(2) - for (int ibs = 0; ibs < batch_size; ++ibs) { - for (int ihn = 0; ihn < heads_kv; ++ihn) { - // K - const auto k_off = (ibs * heads_kv + ihn) * k_rows_pad * k_cols_pad; - for (int i = 0; i < k_rows_pad; ++i) { - for (int j = 0; j < k_cols_pad; ++j) { - if (i < head_size && j < sl_kv) continue; - - const auto j_remain = j % NTILE; - const auto j_block = j - j_remain; - const auto i_remain = i % ROWPACK; - const auto i_block = i - i_remain; - src_k[k_off + j_block * k_rows_pad + i_block * NTILE + j_remain * ROWPACK + i_remain] = K_T(0); - } - } - // V - const auto v_off = (ibs * heads_kv + ihn) * v_rows_pad * v_cols_pad; - for (int i = 0; i < v_rows_pad; ++i) { - for (int j = 0; j < v_cols_pad; ++j) { - if (i < sl_kv && j < head_size) continue; - - const auto j_remain = j % NTILE; - const auto j_block = j - j_remain; - const auto i_remain = i % ROWPACK; - const auto i_block = i - i_remain; - src_v[v_off + j_block * v_rows_pad + i_block * NTILE + j_remain * ROWPACK + i_remain] = V_T(0); - } - } - } - } - } - - attn_fwd_args_t args{ - /* .Q = */ src_q.data(), - /* .K = */ src_k.data(), - /* .V = */ src_v.data(), - /* .dst = */ ref.data(), - /* .Q_sc = */ init_scale_val, - /* .K_sc = */ init_scale_val, - /* .V_sc = */ init_scale_val, - /* .dst_sc = */ init_scale_val, - /* .tmp = */ tmp.data(), - /* .QK_scale = */ 1.f / sqrtf(static_cast(head_size)), - /* .attn_flags = */ flags, - /* .batch_size = */ batch_size, - /* .head_num = */ head_num, - /* .heads_kv = */ heads_kv, - /* .head_size = */ head_size, - /* .sl_q = */ sl_q, - /* .sl_kv = */ sl_kv, - /* .Q_layout = */ ATTN_FWD_LAYOUT_PLAIN, - /* .K_layout = */ kv_layout, - /* .V_layout = */ kv_layout, - /* .dst_layout = */ ATTN_FWD_LAYOUT_PLAIN, - /* .step_q_bs = */ sl_q * head_num * head_size, - /* .step_q_head_num = */ head_size, - /* .step_q_sl = */ head_num * head_size, - /* .step_k_bs = */ sl_kv * heads_kv * head_size, - /* .step_k_head_num = */ k_trans ? head_size * sl_kv : head_size, - /* .step_k_sl = */ k_trans ? 1 : heads_kv * head_size, - /* .step_k_head_size = */ k_trans ? sl_kv : 1, - /* .step_v_bs = */ sl_kv * heads_kv * head_size, - /* .step_v_head_num = */ head_size, - /* .step_v_sl = */ heads_kv * head_size, - /* .step_v_head_size = */ 1, - /* .step_dst_bs = */ sl_q * head_num * head_size, - /* .step_dst_head_num = */ head_size, - /* .step_dst_sl = */ head_num * head_size, - }; - if (kv_layout != ATTN_FWD_LAYOUT_PLAIN) { - args.step_k_bs = heads_kv * k_rows_pad * k_cols_pad; - args.step_k_head_num = k_rows_pad * k_cols_pad; - args.step_k_sl = k_rows_pad; - args.step_k_head_size = NTILE; - args.step_v_bs = heads_kv * v_rows_pad * v_cols_pad; - args.step_v_head_num = v_rows_pad * v_cols_pad; - args.step_v_sl = NTILE; - args.step_v_head_size = v_rows_pad; - } - - bestla_fusion_attn_forward_ref(args); - - args.dst = dst.data(); - bestla_fusion_attn_forward(args); - - // Check result - return compare_data(dst.data(), ref.data(), dst.size(), eps); - } - - template - bool test_reorder_pipe(const attn_shape_t& s, int sl_kv_max, ne_attn_flags_t flags) { - const auto batch_size = s.batch_size; - const auto head_num = s.head_num; - const auto heads_kv = s.heads_kv; - const auto head_size = s.head_size; - const auto sl_q = s.sl_q; - const auto sl_kv = s.sl_kv; - assert(("head_num must be a multiple of heads_kv!", head_num % heads_kv == 0)); - - const auto is_causal = flags & NE_ATTN_FLAG_IS_CAUSAL ? "maksed" : "unmask"; - const auto is_alibi8 = flags & NE_ATTN_FLAG_IS_ALIBI8 ? "alibi8" : ""; - const auto prefer_fp32 = flags & NE_ATTN_FLAG_PREFER_FP32 ? "FP32" : ""; - printf("\ntest_case: %s\t", __PRETTY_FUNCTION__); - printf("bs_%d hn_%d hkv_%d hs_%d sl_q_%d sk_kv_%d %s %s %s\n", batch_size, head_num, heads_kv, head_size, sl_q, - sl_kv, is_causal, is_alibi8, prefer_fp32); - - assert(sl_kv_max >= sl_kv); - - kv_shape_t kv_shape = { - /* .heads_kv */ static_cast(heads_kv), - /* .head_size */ static_cast(head_size), - /* .sl_kv_max */ static_cast(sl_kv_max), - }; - kv_cache_info_t kv_cache_info; - bestla_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info); - assert(kv_cache_info.k_layout >= kv_cache_info.v_layout); - const auto kv_layout = kv_cache_info.k_layout; - const auto NTILE = kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK4 ? 48 - : kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2 ? 48 - : kv_layout == ATTN_FWD_LAYOUT_NTILE24_ROWPACK1 ? 24 - : 0; - const auto ROWPACK = kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK4 ? 4 - : kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2 ? 2 - : kv_layout == ATTN_FWD_LAYOUT_NTILE24_ROWPACK1 ? 1 - : 0; - const auto ROWPAD = ROWPACK > 1 ? ROWPACK * 16 : 1; - const auto k_rows_pad = kv_layout != ATTN_FWD_LAYOUT_PLAIN ? padto(head_size, ROWPAD) : head_size; - const auto k_cols_pad = kv_layout != ATTN_FWD_LAYOUT_PLAIN ? padto(sl_kv, NTILE) : sl_kv; - const auto v_rows_pad = kv_layout != ATTN_FWD_LAYOUT_PLAIN ? padto(sl_kv, ROWPAD) : sl_kv; - const auto v_cols_pad = kv_layout != ATTN_FWD_LAYOUT_PLAIN ? padto(head_size, NTILE) : head_size; - - std::vector src_q(batch_size * head_num * sl_q * head_size); - std::vector src_k(batch_size * heads_kv * sl_kv * head_size); - std::vector src_v(batch_size * heads_kv * sl_kv * head_size); - std::vector k_cache(batch_size * kv_cache_info.k_bytes); - std::vector v_cache(batch_size * kv_cache_info.v_bytes); - std::vector dst(batch_size * head_num * sl_q * head_size); - std::vector ref(batch_size * head_num * sl_q * head_size); // reference result - std::vector tmp(bestla_fusion_attn_workspace_size(&s)); - - // init vector - static std::mt19937 rng(1); - std::uniform_int_distribution<> dist; - init_vector(&src_q, init_min_val, init_max_val, dist(rng)); - init_vector(&src_k, init_min_val, init_max_val, dist(rng)); - init_vector(&src_v, init_min_val, init_max_val, dist(rng)); - - // undefined values - init_vector(&k_cache, INT8_MIN, INT8_MAX, dist(rng)); - init_vector(&v_cache, INT8_MIN, INT8_MAX, dist(rng)); - - int step_src_k_bs = sl_kv * heads_kv * head_size; - int step_src_k_head_num = head_size; - int step_src_k_sl = heads_kv * head_size; - int step_src_k_head_size = 1; - int step_src_v_bs = sl_kv * heads_kv * head_size; - int step_src_v_head_num = head_size; - int step_src_v_sl = heads_kv * head_size; - int step_src_v_head_size = 1; - attn_fwd_args_t ref_args{ - /* .Q = */ src_q.data(), - /* .K = */ src_k.data(), - /* .V = */ src_v.data(), - /* .dst = */ ref.data(), - /* .Q_sc = */ init_scale_val, - /* .K_sc = */ init_scale_val, - /* .V_sc = */ init_scale_val, - /* .dst_sc = */ init_scale_val, - /* .tmp = */ tmp.data(), - /* .QK_scale = */ 1.f / sqrtf(static_cast(head_size)), - /* .attn_flags = */ flags, - /* .batch_size = */ batch_size, - /* .head_num = */ head_num, - /* .heads_kv = */ heads_kv, - /* .head_size = */ head_size, - /* .sl_q = */ sl_q, - /* .sl_kv = */ sl_kv, - /* .Q_layout = */ ATTN_FWD_LAYOUT_PLAIN, - /* .K_layout = */ ATTN_FWD_LAYOUT_PLAIN, - /* .V_layout = */ ATTN_FWD_LAYOUT_PLAIN, - /* .dst_layout = */ ATTN_FWD_LAYOUT_PLAIN, - /* .step_q_bs = */ sl_q * head_num * head_size, - /* .step_q_head_num = */ head_size, - /* .step_q_sl = */ head_num * head_size, - - /* .step_k_bs = */ step_src_k_bs, - /* .step_k_head_num = */ step_src_k_head_num, - /* .step_k_sl = */ step_src_k_sl, - /* .step_k_head_size = */ step_src_k_head_size, - /* .step_v_bs = */ step_src_v_bs, - /* .step_v_head_num = */ step_src_v_head_num, - /* .step_v_sl = */ step_src_v_sl, - /* .step_v_head_size = */ step_src_v_head_size, - - /* .step_dst_bs = */ sl_q * head_num * head_size, - /* .step_dst_head_num = */ head_size, - /* .step_dst_sl = */ head_num * head_size, - }; - bestla_fusion_attn_forward_ref(ref_args); - - if (std::is_same, std::tuple>::value) { - assert(kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2 || kv_layout == ATTN_FWD_LAYOUT_NTILE24_ROWPACK1); - // for testing, first reorder sl_kv - 1 and than concat the last 1 line - const auto seq_size_first = sl_kv - 1; - const auto seq_size_next = 1; - bestla_fusion_attn_fp32_update_kv_args_t update_k_args = { - /* .src = */ src_k.data(), - /* .cache = */ k_cache.data(), - /* .batch_size = */ batch_size, - /* .heads_kv = */ heads_kv, - /* .head_size = */ head_size, - /* .seq_off = */ 0, - /* .seq_size = */ seq_size_first, - /* .seq_max = */ sl_kv_max, - /* .step_bs = */ step_src_k_bs, - /* .step_head_num = */ step_src_k_head_num, - /* .step_seq = */ step_src_k_sl, - /* .step_head_size = */ step_src_k_head_size, - }; - bestla_reordered_attn_fp32_update_k(&update_k_args); - - bestla_fusion_attn_fp32_update_kv_args_t update_v_args = { - /* .src = */ src_v.data(), - /* .cache = */ v_cache.data(), - /* .batch_size = */ batch_size, - /* .heads_kv = */ heads_kv, - /* .head_size = */ head_size, - /* .seq_off = */ 0, - /* .seq_size = */ seq_size_first, - /* .seq_max = */ sl_kv_max, - /* .step_bs = */ step_src_v_bs, - /* .step_head_num = */ step_src_v_head_num, - /* .step_seq = */ step_src_v_sl, - /* .step_head_size = */ step_src_v_head_size, - }; - bestla_reordered_attn_fp32_update_v(&update_v_args); - - update_k_args.seq_off = seq_size_first; - update_k_args.seq_size = seq_size_next; - update_k_args.src = src_k.data() + seq_size_first * step_src_k_sl; - bestla_reordered_attn_fp32_update_k(&update_k_args); - - update_v_args.seq_off = seq_size_first; - update_v_args.seq_size = seq_size_next; - update_v_args.src = src_v.data() + seq_size_first * step_src_v_sl; - bestla_reordered_attn_fp32_update_v(&update_v_args); - - bestla_reordered_attn_fp32_fp32_fwd_args_t kern_args{ - /* .Q = */ reinterpret_cast(src_q.data()), - /* .K = */ k_cache.data(), - /* .V = */ v_cache.data(), - /* .dst = */ reinterpret_cast(dst.data()), - /* .Q_sc = */ init_scale_val, - /* .K_sc = */ init_scale_val, - /* .V_sc = */ init_scale_val, - /* .dst_sc = */ init_scale_val, - /* .tmp = */ tmp.data(), - /* .QK_scale = */ 1.f / sqrtf(static_cast(head_size)), - /* .attn_flags = */ flags, - /* .batch_size = */ batch_size, - /* .head_num = */ head_num, - /* .heads_kv = */ heads_kv, - /* .head_size = */ head_size, - /* .sl_q = */ sl_q, - /* .sl_kv = */ sl_kv, - /* .Q_layout = */ ATTN_FWD_LAYOUT_PLAIN, - /* .K_layout = */ kv_layout, - /* .V_layout = */ kv_layout, - /* .dst_layout = */ ATTN_FWD_LAYOUT_PLAIN, - /* .step_q_bs = */ sl_q * head_num * head_size, - /* .step_q_head_num = */ head_size, - /* .step_q_sl = */ head_num * head_size, - - /* .stride_k_bs = */ static_cast(kv_cache_info.k_bytes), - /* .stride_k_head_num = */ kv_cache_info.stride_k_head_num, - /* .stride_k_sl = */ kv_cache_info.stride_k_sl, - /* .stride_k_head_size = */ kv_cache_info.stride_k_head_size, - /* .stride_v_bs = */ static_cast(kv_cache_info.v_bytes), - /* .stride_v_head_num = */ kv_cache_info.stride_v_head_num, - /* .stride_v_sl = */ kv_cache_info.stride_v_sl, - /* .stride_v_head_size = */ kv_cache_info.stride_v_head_size, - - /* .step_dst_bs = */ sl_q * head_num * head_size, - /* .step_dst_head_num = */ head_size, - /* .step_dst_sl = */ head_num * head_size, - }; - bestla_reordered_attn_fp32_forward(&kern_args); - } - - // Check result - return compare_data(dst.data(), ref.data(), dst.size(), 1e-2f); - } -}; -static const TestMhaDese inst_; - -} // namespace - -int main() { - printf("NS_TESTS: mha_dense "); - printf(ret_ok ? "OK\n" : "FAILED\n"); - return ret_ok ? 0 : -1; -} -#endif diff --git a/neural_speed/core/layers/mha_dense_tests.cpp b/neural_speed/core/layers/mha_dense_tests.cpp new file mode 100644 index 000000000..155423268 --- /dev/null +++ b/neural_speed/core/layers/mha_dense_tests.cpp @@ -0,0 +1,499 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include + +#include "layers/mha_dense.h" +#include "layers/mha_dense_wrapper.h" +#include "layers/ne_test_layers_utils.hpp" + +#ifndef NS_TESTS +static_assert(false, "Only compile this source file for testing!"); +#endif + +using namespace ne_bestla::custom::mha; // NOLINT + +#define CheckISA(ISA) \ + (bestla::device::CpuDevice::getInstance()->ISA() || (printf("Wrong Device ISA: " #ISA "\n"), false)) + +namespace { +bool ret_ok = true; + +class test_mha_dese_t { + public: + test_mha_dese_t() { + printf("Test suit: %s\n", __FUNCTION__); + GetCPUDevice(); + static const int max_threads = std::thread::hardware_concurrency(); + ne_threading::get()->set_threads(std::min(_cd->getThreads(), max_threads)); + +#if CompileFP16() + if (CheckISA(AMX_BF16)) { + ret_ok &= test_case({1, 1, 1, 32, 128, 64}, NE_ATTN_FLAG_NONE); + ret_ok &= test_case({2, 5, 5, 32, 64, 128}, NE_ATTN_FLAG_NONE); + ret_ok &= test_case({2, 5, 5, 80, 128, 77}, NE_ATTN_FLAG_NONE); + ret_ok &= test_case({1, 1, 1, 32, 63, 63}, NE_ATTN_FLAG_NONE); + ret_ok &= test_case({3, 4, 4, 256, 1, 384}, NE_ATTN_FLAG_NONE); + ret_ok &= test_case({1, 1, 1, 64, 64, 64}, NE_ATTN_FLAG_IS_CAUSAL); + + ret_ok &= test_case({1, 1, 1, 32, 128, 64}, NE_ATTN_FLAG_NONE, true); + ret_ok &= test_case({2, 5, 5, 32, 64, 128}, NE_ATTN_FLAG_NONE, true); + ret_ok &= test_case({2, 5, 5, 80, 128, 77}, NE_ATTN_FLAG_NONE, true); + ret_ok &= test_case({1, 1, 1, 256, 63, 63}, NE_ATTN_FLAG_NONE, true); + ret_ok &= test_case({3, 4, 4, 256, 1, 384}, NE_ATTN_FLAG_NONE, true); + ret_ok &= test_case({1, 1, 1, 64, 64, 64}, NE_ATTN_FLAG_IS_CAUSAL, true); + + ret_ok &= test_case({1, 1, 1, 32, 128, 64}, NE_ATTN_FLAG_NONE, true); + ret_ok &= test_case({2, 5, 5, 32, 64, 128}, NE_ATTN_FLAG_NONE, true); + ret_ok &= test_case({2, 5, 5, 80, 128, 77}, NE_ATTN_FLAG_NONE, true); + ret_ok &= test_case({1, 1, 1, 256, 63, 63}, NE_ATTN_FLAG_NONE, true); + ret_ok &= test_case({3, 4, 4, 256, 1, 384}, NE_ATTN_FLAG_NONE, true); + ret_ok &= test_case({1, 1, 1, 64, 64, 64}, NE_ATTN_FLAG_IS_CAUSAL, true); + } +#endif + + if (CheckISA(AMX_BF16)) { + const auto BA48b4a = ATTN_FWD_LAYOUT_NTILE48_ROWPACK4; + ret_ok &= test_case({1, 1, 1, 32, 128, 64}, NE_ATTN_FLAG_NONE, false, BA48b4a); + ret_ok &= test_case({2, 5, 5, 32, 64, 128}, NE_ATTN_FLAG_NONE, false, BA48b4a); + ret_ok &= test_case({2, 5, 5, 80, 128, 77}, NE_ATTN_FLAG_NONE, false, BA48b4a); + ret_ok &= test_case({1, 1, 1, 256, 63, 63}, NE_ATTN_FLAG_NONE, false, BA48b4a); + ret_ok &= test_case({3, 4, 4, 256, 1, 384}, NE_ATTN_FLAG_NONE, false, BA48b4a); + ret_ok &= + test_case({1, 1, 1, 64, 64, 64}, NE_ATTN_FLAG_IS_CAUSAL, false, BA48b4a); + } + + if (CheckISA(AMX_BF16)) { + const auto BA48b2a = ATTN_FWD_LAYOUT_NTILE48_ROWPACK2; + int flags = NE_ATTN_FLAG_NONE; + ret_ok &= test_case({1, 1, 1, 32, 128, 64}, flags, false, BA48b2a, 1e-3f); + ret_ok &= test_case({2, 5, 5, 32, 64, 128}, flags, false, BA48b2a, 1e-3f); + ret_ok &= test_case({2, 5, 5, 80, 128, 77}, flags, false, BA48b2a, 1e-3f); + ret_ok &= test_case({1, 1, 1, 256, 63, 63}, flags, false, BA48b2a, 1e-3f); + ret_ok &= test_case({3, 4, 4, 256, 1, 384}, flags, false, BA48b2a, 1e-3f); + + flags |= NE_ATTN_FLAG_IS_CAUSAL; + ret_ok &= test_case({1, 1, 1, 64, 64, 64}, flags, false, BA48b2a, 1e-3f); + } + + if (CheckISA(AVX512F)) { // PREFER_FP32 + const auto BA48b2a = ATTN_FWD_LAYOUT_NTILE48_ROWPACK2; + int flags = NE_ATTN_FLAG_PREFER_FP32; + ret_ok &= test_case({1, 1, 1, 32, 128, 64}, flags, false, BA48b2a, 1e-3f); + ret_ok &= test_case({2, 5, 5, 32, 64, 128}, flags, false, BA48b2a, 1e-3f); + ret_ok &= test_case({2, 5, 5, 80, 128, 77}, flags, false, BA48b2a, 1e-3f); + ret_ok &= test_case({1, 1, 1, 256, 63, 63}, flags, false, BA48b2a, 1e-3f); + ret_ok &= test_case({3, 4, 4, 256, 1, 384}, flags, false, BA48b2a, 1e-3f); + + flags |= NE_ATTN_FLAG_IS_CAUSAL; + ret_ok &= test_case({1, 1, 1, 64, 64, 64}, flags, false, BA48b2a, 1e-3f); + } + if (CheckISA(AVX2)) { // avx2 + const auto Ba24b = ATTN_FWD_LAYOUT_NTILE24_ROWPACK1; + int flags = NE_ATTN_FLAG_PREFER_FP32; + ret_ok &= test_case({1, 1, 1, 32, 128, 64}, flags, false, Ba24b, 1e-3f); + ret_ok &= test_case({2, 5, 5, 32, 64, 128}, flags, false, Ba24b, 1e-3f); + ret_ok &= test_case({2, 5, 5, 80, 128, 77}, flags, false, Ba24b, 1e-3f); + ret_ok &= test_case({1, 1, 1, 256, 63, 63}, flags, false, Ba24b, 1e-3f); + ret_ok &= test_case({3, 4, 4, 256, 1, 384}, flags, false, Ba24b, 1e-3f); + + flags |= NE_ATTN_FLAG_IS_CAUSAL; + ret_ok &= test_case({1, 1, 1, 64, 64, 64}, flags, false, Ba24b, 1e-3f); + } + + { // amxbf16 => avx2 fallback + int flags = NE_ATTN_FLAG_NONE; + ret_ok &= test_reorder_pipe({1, 1, 1, 32, 128, 64}, 64, flags); + ret_ok &= test_reorder_pipe({2, 5, 5, 32, 64, 128}, 256, flags); + ret_ok &= test_reorder_pipe({2, 5, 5, 80, 128, 77}, 256, flags); + ret_ok &= test_reorder_pipe({2, 5, 1, 80, 128, 77}, 256, flags); + ret_ok &= test_reorder_pipe({1, 1, 1, 256, 63, 63}, 256, flags); + ret_ok &= test_reorder_pipe({3, 4, 4, 256, 1, 384}, 384, flags); + ret_ok &= test_reorder_pipe({3, 4, 2, 256, 1, 384}, 384, flags); + flags |= NE_ATTN_FLAG_IS_CAUSAL; + ret_ok &= test_reorder_pipe({1, 1, 1, 64, 64, 64}, 128, flags); + flags |= NE_ATTN_FLAG_IS_ALIBI8; + ret_ok &= test_reorder_pipe({1, 8, 8, 64, 64, 64}, 128, flags); + } + printf("Test suit done: %s\n", __FUNCTION__); + } + + template + static constexpr float init_min_val = std::is_same::value ? -127.f + : std::is_same::value ? 0.f + : -1.f; + template + static constexpr float init_max_val = std::is_same::value ? 127.f + : std::is_same::value ? 255.f + : 1.f; + template + static constexpr float init_scale_val = 1.f / init_max_val; + +#ifdef _MSC_VER +#define __PRETTY_FUNCTION__ __FUNCSIG__ +#endif + + template + bool test_case(const attn_shape_t& s, ne_attn_flags_t flags, bool k_trans = false, + ATTN_FWD_LAYOUT kv_layout = ATTN_FWD_LAYOUT_PLAIN, float eps = 1e-2f) { + assert(kv_layout == ATTN_FWD_LAYOUT_PLAIN || !k_trans); + const auto batch_size = s.batch_size; + const auto head_num = s.head_num; + const auto heads_kv = s.heads_kv; + const auto head_size = s.head_size; + const auto sl_q = s.sl_q; + const auto sl_kv = s.sl_kv; + assert(("GQA not supported!", s.head_num == s.heads_kv)); + + const auto is_causal = flags & NE_ATTN_FLAG_IS_CAUSAL ? "maksed" : "unmask"; + const auto is_alibi8 = flags & NE_ATTN_FLAG_IS_ALIBI8 ? "alibi8" : ""; + const auto prefer_fp32 = flags & NE_ATTN_FLAG_PREFER_FP32 ? "FP32" : ""; + printf("\ntest_case: %s\t", __PRETTY_FUNCTION__); + printf("bs_%d hn_%d hkv_%d hs_%d sl_q_%d sk_kv_%d %s %s %s\n", batch_size, head_num, heads_kv, head_size, sl_q, + sl_kv, is_causal, is_alibi8, prefer_fp32); + + const auto NTILE = kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK4 ? 48 + : kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2 ? 48 + : kv_layout == ATTN_FWD_LAYOUT_NTILE24_ROWPACK1 ? 24 + : 0; + const auto ROWPACK = kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK4 ? 4 + : kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2 ? 2 + : kv_layout == ATTN_FWD_LAYOUT_NTILE24_ROWPACK1 ? 1 + : 0; + const auto ROWPAD = ROWPACK > 1 ? ROWPACK * 16 : 1; + const auto k_rows_pad = kv_layout != ATTN_FWD_LAYOUT_PLAIN ? padto(head_size, ROWPAD) : head_size; + const auto k_cols_pad = kv_layout != ATTN_FWD_LAYOUT_PLAIN ? padto(sl_kv, NTILE) : sl_kv; + const auto v_rows_pad = kv_layout != ATTN_FWD_LAYOUT_PLAIN ? padto(sl_kv, ROWPAD) : sl_kv; + const auto v_cols_pad = kv_layout != ATTN_FWD_LAYOUT_PLAIN ? padto(head_size, NTILE) : head_size; + + std::vector src_q(batch_size * head_num * sl_q * head_size); + std::vector src_k(batch_size * heads_kv * k_rows_pad * k_cols_pad); + std::vector src_v(batch_size * heads_kv * v_rows_pad * v_cols_pad); + std::vector dst(batch_size * head_num * sl_q * head_size); + std::vector ref(batch_size * head_num * sl_q * head_size); // reference result + std::vector tmp(bestla_fusion_attn_workspace_size(&s)); + + // init vector + static std::mt19937 rng(1); + std::uniform_int_distribution<> dist; + init_vector(&src_q, init_min_val, init_max_val, dist(rng)); + init_vector(&src_k, init_min_val, init_max_val, dist(rng)); + init_vector(&src_v, init_min_val, init_max_val, dist(rng)); + + // pad0 for padded layouts + if (kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK4 || kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2 || + kv_layout == ATTN_FWD_LAYOUT_NTILE24_ROWPACK1) { +#pragma omp parallel for collapse(2) + for (int ibs = 0; ibs < batch_size; ++ibs) { + for (int ihn = 0; ihn < heads_kv; ++ihn) { + // K + const auto k_off = (ibs * heads_kv + ihn) * k_rows_pad * k_cols_pad; + for (int i = 0; i < k_rows_pad; ++i) { + for (int j = 0; j < k_cols_pad; ++j) { + if (i < head_size && j < sl_kv) continue; + + const auto j_remain = j % NTILE; + const auto j_block = j - j_remain; + const auto i_remain = i % ROWPACK; + const auto i_block = i - i_remain; + src_k[k_off + j_block * k_rows_pad + i_block * NTILE + j_remain * ROWPACK + i_remain] = K_T(0); + } + } + // V + const auto v_off = (ibs * heads_kv + ihn) * v_rows_pad * v_cols_pad; + for (int i = 0; i < v_rows_pad; ++i) { + for (int j = 0; j < v_cols_pad; ++j) { + if (i < sl_kv && j < head_size) continue; + + const auto j_remain = j % NTILE; + const auto j_block = j - j_remain; + const auto i_remain = i % ROWPACK; + const auto i_block = i - i_remain; + src_v[v_off + j_block * v_rows_pad + i_block * NTILE + j_remain * ROWPACK + i_remain] = V_T(0); + } + } + } + } + } + + attn_fwd_args_t args{ + /* .Q = */ src_q.data(), + /* .K = */ src_k.data(), + /* .V = */ src_v.data(), + /* .dst = */ ref.data(), + /* .Q_sc = */ init_scale_val, + /* .K_sc = */ init_scale_val, + /* .V_sc = */ init_scale_val, + /* .dst_sc = */ init_scale_val, + /* .tmp = */ tmp.data(), + /* .QK_scale = */ 1.f / sqrtf(static_cast(head_size)), + /* .attn_flags = */ flags, + /* .batch_size = */ batch_size, + /* .head_num = */ head_num, + /* .heads_kv = */ heads_kv, + /* .head_size = */ head_size, + /* .sl_q = */ sl_q, + /* .sl_kv = */ sl_kv, + /* .Q_layout = */ ATTN_FWD_LAYOUT_PLAIN, + /* .K_layout = */ kv_layout, + /* .V_layout = */ kv_layout, + /* .dst_layout = */ ATTN_FWD_LAYOUT_PLAIN, + /* .step_q_bs = */ sl_q * head_num * head_size, + /* .step_q_head_num = */ head_size, + /* .step_q_sl = */ head_num * head_size, + /* .step_k_bs = */ sl_kv * heads_kv * head_size, + /* .step_k_head_num = */ k_trans ? head_size * sl_kv : head_size, + /* .step_k_sl = */ k_trans ? 1 : heads_kv * head_size, + /* .step_k_head_size = */ k_trans ? sl_kv : 1, + /* .step_v_bs = */ sl_kv * heads_kv * head_size, + /* .step_v_head_num = */ head_size, + /* .step_v_sl = */ heads_kv * head_size, + /* .step_v_head_size = */ 1, + /* .step_dst_bs = */ sl_q * head_num * head_size, + /* .step_dst_head_num = */ head_size, + /* .step_dst_sl = */ head_num * head_size, + }; + if (kv_layout != ATTN_FWD_LAYOUT_PLAIN) { + args.step_k_bs = heads_kv * k_rows_pad * k_cols_pad; + args.step_k_head_num = k_rows_pad * k_cols_pad; + args.step_k_sl = k_rows_pad; + args.step_k_head_size = NTILE; + args.step_v_bs = heads_kv * v_rows_pad * v_cols_pad; + args.step_v_head_num = v_rows_pad * v_cols_pad; + args.step_v_sl = NTILE; + args.step_v_head_size = v_rows_pad; + } + + bestla_fusion_attn_forward_ref(args); + + args.dst = dst.data(); + bestla_fusion_attn_forward(args); + + // Check result + return compare_data(dst.data(), ref.data(), dst.size(), eps); + } + + template + bool test_reorder_pipe(const attn_shape_t& s, int sl_kv_max, ne_attn_flags_t flags) { + const auto batch_size = s.batch_size; + const auto head_num = s.head_num; + const auto heads_kv = s.heads_kv; + const auto head_size = s.head_size; + const auto sl_q = s.sl_q; + const auto sl_kv = s.sl_kv; + assert(("head_num must be a multiple of heads_kv!", head_num % heads_kv == 0)); + + const auto is_causal = flags & NE_ATTN_FLAG_IS_CAUSAL ? "maksed" : "unmask"; + const auto is_alibi8 = flags & NE_ATTN_FLAG_IS_ALIBI8 ? "alibi8" : ""; + const auto prefer_fp32 = flags & NE_ATTN_FLAG_PREFER_FP32 ? "FP32" : ""; + printf("\ntest_case: %s\t", __PRETTY_FUNCTION__); + printf("bs_%d hn_%d hkv_%d hs_%d sl_q_%d sk_kv_%d %s %s %s\n", batch_size, head_num, heads_kv, head_size, sl_q, + sl_kv, is_causal, is_alibi8, prefer_fp32); + + assert(sl_kv_max >= sl_kv); + + kv_shape_t kv_shape = { + /* .heads_kv */ static_cast(heads_kv), + /* .head_size */ static_cast(head_size), + /* .sl_kv_max */ static_cast(sl_kv_max), + }; + kv_cache_info_t kv_cache_info; + bestla_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info); + assert(kv_cache_info.k_layout >= kv_cache_info.v_layout); + const auto kv_layout = kv_cache_info.k_layout; + const auto NTILE = kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK4 ? 48 + : kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2 ? 48 + : kv_layout == ATTN_FWD_LAYOUT_NTILE24_ROWPACK1 ? 24 + : 0; + const auto ROWPACK = kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK4 ? 4 + : kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2 ? 2 + : kv_layout == ATTN_FWD_LAYOUT_NTILE24_ROWPACK1 ? 1 + : 0; + const auto ROWPAD = ROWPACK > 1 ? ROWPACK * 16 : 1; + const auto k_rows_pad = kv_layout != ATTN_FWD_LAYOUT_PLAIN ? padto(head_size, ROWPAD) : head_size; + const auto k_cols_pad = kv_layout != ATTN_FWD_LAYOUT_PLAIN ? padto(sl_kv, NTILE) : sl_kv; + const auto v_rows_pad = kv_layout != ATTN_FWD_LAYOUT_PLAIN ? padto(sl_kv, ROWPAD) : sl_kv; + const auto v_cols_pad = kv_layout != ATTN_FWD_LAYOUT_PLAIN ? padto(head_size, NTILE) : head_size; + + std::vector src_q(batch_size * head_num * sl_q * head_size); + std::vector src_k(batch_size * heads_kv * sl_kv * head_size); + std::vector src_v(batch_size * heads_kv * sl_kv * head_size); + std::vector k_cache(batch_size * kv_cache_info.k_bytes); + std::vector v_cache(batch_size * kv_cache_info.v_bytes); + std::vector dst(batch_size * head_num * sl_q * head_size); + std::vector ref(batch_size * head_num * sl_q * head_size); // reference result + std::vector tmp(bestla_fusion_attn_workspace_size(&s)); + + // init vector + static std::mt19937 rng(1); + std::uniform_int_distribution<> dist; + init_vector(&src_q, init_min_val, init_max_val, dist(rng)); + init_vector(&src_k, init_min_val, init_max_val, dist(rng)); + init_vector(&src_v, init_min_val, init_max_val, dist(rng)); + + // undefined values + init_vector(&k_cache, INT8_MIN, INT8_MAX, dist(rng)); + init_vector(&v_cache, INT8_MIN, INT8_MAX, dist(rng)); + + int step_src_k_bs = sl_kv * heads_kv * head_size; + int step_src_k_head_num = head_size; + int step_src_k_sl = heads_kv * head_size; + int step_src_k_head_size = 1; + int step_src_v_bs = sl_kv * heads_kv * head_size; + int step_src_v_head_num = head_size; + int step_src_v_sl = heads_kv * head_size; + int step_src_v_head_size = 1; + attn_fwd_args_t ref_args{ + /* .Q = */ src_q.data(), + /* .K = */ src_k.data(), + /* .V = */ src_v.data(), + /* .dst = */ ref.data(), + /* .Q_sc = */ init_scale_val, + /* .K_sc = */ init_scale_val, + /* .V_sc = */ init_scale_val, + /* .dst_sc = */ init_scale_val, + /* .tmp = */ tmp.data(), + /* .QK_scale = */ 1.f / sqrtf(static_cast(head_size)), + /* .attn_flags = */ flags, + /* .batch_size = */ batch_size, + /* .head_num = */ head_num, + /* .heads_kv = */ heads_kv, + /* .head_size = */ head_size, + /* .sl_q = */ sl_q, + /* .sl_kv = */ sl_kv, + /* .Q_layout = */ ATTN_FWD_LAYOUT_PLAIN, + /* .K_layout = */ ATTN_FWD_LAYOUT_PLAIN, + /* .V_layout = */ ATTN_FWD_LAYOUT_PLAIN, + /* .dst_layout = */ ATTN_FWD_LAYOUT_PLAIN, + /* .step_q_bs = */ sl_q * head_num * head_size, + /* .step_q_head_num = */ head_size, + /* .step_q_sl = */ head_num * head_size, + + /* .step_k_bs = */ step_src_k_bs, + /* .step_k_head_num = */ step_src_k_head_num, + /* .step_k_sl = */ step_src_k_sl, + /* .step_k_head_size = */ step_src_k_head_size, + /* .step_v_bs = */ step_src_v_bs, + /* .step_v_head_num = */ step_src_v_head_num, + /* .step_v_sl = */ step_src_v_sl, + /* .step_v_head_size = */ step_src_v_head_size, + + /* .step_dst_bs = */ sl_q * head_num * head_size, + /* .step_dst_head_num = */ head_size, + /* .step_dst_sl = */ head_num * head_size, + }; + bestla_fusion_attn_forward_ref(ref_args); + + if (std::is_same, std::tuple>::value) { + assert(kv_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2 || kv_layout == ATTN_FWD_LAYOUT_NTILE24_ROWPACK1); + // for testing, first reorder sl_kv - 1 and than concat the last 1 line + const auto seq_size_first = sl_kv - 1; + const auto seq_size_next = 1; + bestla_fusion_attn_fp32_update_kv_args_t update_k_args = { + /* .src = */ src_k.data(), + /* .cache = */ k_cache.data(), + /* .batch_size = */ batch_size, + /* .heads_kv = */ heads_kv, + /* .head_size = */ head_size, + /* .seq_off = */ 0, + /* .seq_size = */ seq_size_first, + /* .seq_max = */ sl_kv_max, + /* .step_bs = */ step_src_k_bs, + /* .step_head_num = */ step_src_k_head_num, + /* .step_seq = */ step_src_k_sl, + /* .step_head_size = */ step_src_k_head_size, + }; + bestla_reordered_attn_fp32_update_k(&update_k_args); + + bestla_fusion_attn_fp32_update_kv_args_t update_v_args = { + /* .src = */ src_v.data(), + /* .cache = */ v_cache.data(), + /* .batch_size = */ batch_size, + /* .heads_kv = */ heads_kv, + /* .head_size = */ head_size, + /* .seq_off = */ 0, + /* .seq_size = */ seq_size_first, + /* .seq_max = */ sl_kv_max, + /* .step_bs = */ step_src_v_bs, + /* .step_head_num = */ step_src_v_head_num, + /* .step_seq = */ step_src_v_sl, + /* .step_head_size = */ step_src_v_head_size, + }; + bestla_reordered_attn_fp32_update_v(&update_v_args); + + update_k_args.seq_off = seq_size_first; + update_k_args.seq_size = seq_size_next; + update_k_args.src = src_k.data() + seq_size_first * step_src_k_sl; + bestla_reordered_attn_fp32_update_k(&update_k_args); + + update_v_args.seq_off = seq_size_first; + update_v_args.seq_size = seq_size_next; + update_v_args.src = src_v.data() + seq_size_first * step_src_v_sl; + bestla_reordered_attn_fp32_update_v(&update_v_args); + + bestla_reordered_attn_fp32_fp32_fwd_args_t kern_args{ + /* .Q = */ reinterpret_cast(src_q.data()), + /* .K = */ k_cache.data(), + /* .V = */ v_cache.data(), + /* .dst = */ reinterpret_cast(dst.data()), + /* .Q_sc = */ init_scale_val, + /* .K_sc = */ init_scale_val, + /* .V_sc = */ init_scale_val, + /* .dst_sc = */ init_scale_val, + /* .tmp = */ tmp.data(), + /* .QK_scale = */ 1.f / sqrtf(static_cast(head_size)), + /* .attn_flags = */ flags, + /* .batch_size = */ batch_size, + /* .head_num = */ head_num, + /* .heads_kv = */ heads_kv, + /* .head_size = */ head_size, + /* .sl_q = */ sl_q, + /* .sl_kv = */ sl_kv, + /* .Q_layout = */ ATTN_FWD_LAYOUT_PLAIN, + /* .K_layout = */ kv_layout, + /* .V_layout = */ kv_layout, + /* .dst_layout = */ ATTN_FWD_LAYOUT_PLAIN, + /* .step_q_bs = */ sl_q * head_num * head_size, + /* .step_q_head_num = */ head_size, + /* .step_q_sl = */ head_num * head_size, + + /* .stride_k_bs = */ static_cast(kv_cache_info.k_bytes), + /* .stride_k_head_num = */ kv_cache_info.stride_k_head_num, + /* .stride_k_sl = */ kv_cache_info.stride_k_sl, + /* .stride_k_head_size = */ kv_cache_info.stride_k_head_size, + /* .stride_v_bs = */ static_cast(kv_cache_info.v_bytes), + /* .stride_v_head_num = */ kv_cache_info.stride_v_head_num, + /* .stride_v_sl = */ kv_cache_info.stride_v_sl, + /* .stride_v_head_size = */ kv_cache_info.stride_v_head_size, + + /* .step_dst_bs = */ sl_q * head_num * head_size, + /* .step_dst_head_num = */ head_size, + /* .step_dst_sl = */ head_num * head_size, + }; + bestla_reordered_attn_fp32_forward(&kern_args); + } + + // Check result + return compare_data(dst.data(), ref.data(), dst.size(), 1e-2f); + } +}; +const test_mha_dese_t inst_; + +} // namespace + +int main() { + printf("NS_TESTS: mha_dense "); + printf(ret_ok ? "OK\n" : "FAILED\n"); + return ret_ok ? 0 : -1; +} diff --git a/neural_speed/core/layers/mha_dense_wrapper.h b/neural_speed/core/layers/mha_dense_wrapper.h index 174fc39e4..be586a305 100644 --- a/neural_speed/core/layers/mha_dense_wrapper.h +++ b/neural_speed/core/layers/mha_dense_wrapper.h @@ -1,3 +1,19 @@ +// Copyright (c) 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef NE_CORE_GRAPH_MHA_DENSE_WRAPPER_H +#define NE_CORE_GRAPH_MHA_DENSE_WRAPPER_H + #include #include @@ -17,6 +33,13 @@ #include "core/data_types.h" #include "layers/bestla_common.hpp" +#ifdef NS_TESTS +#include +#include + +#include "layers/ne_test_layers_utils.hpp" +#endif + #define MHA_2ND_EXP 1 constexpr bool MHA_PREFER_AVX512FP16 = true; @@ -1762,14 +1785,14 @@ class mha_stable_interface_t { }; template -void bestla_fusion_attn_forward(const attn_fwd_args_t& params) = delete; +inline void bestla_fusion_attn_forward(const attn_fwd_args_t& params) = delete; template using WeightPackBatchBf16Bf16NonTr = weight_pack_batch_bf16_non_tr_t; template using WeightPackBatchBf16Bf16Trans = weight_pack_batch_bf16_trans_t; template <> -void bestla_fusion_attn_forward(const attn_fwd_args_t& p) { +inline void bestla_fusion_attn_forward(const attn_fwd_args_t& p) { using GemmKernelBF16ExpSum = mha::launcher_base_off_t< // BTLA_ISA::AMX_BF16, // gemm::HCoreRowNAmxbf16<64, 16>, // @@ -1793,7 +1816,8 @@ using WeightPackBatchFp16Bf16NonTr = weight_pack_batch_bf16_non_tr_t using WeightPackBatchFp16Bf16Trans = weight_pack_batch_bf16_trans_t; template <> -void bestla_fusion_attn_forward(const attn_fwd_args_t& params) { +inline void bestla_fusion_attn_forward( + const attn_fwd_args_t& params) { GetCPUDevice(); const auto pth = ne_threading::get(); if (MHA_PREFER_AVX512FP16 && _cd->AVX512_FP16() && params.step_k_sl == 1) { @@ -1872,7 +1896,7 @@ void bestla_fusion_attn_forward(const attn_fwd_args_t< } template <> -void bestla_fusion_attn_forward(const attn_fwd_args_t& params) { +inline void bestla_fusion_attn_forward(const attn_fwd_args_t& params) { GetCPUDevice(); const auto pth = ne_threading::get(); if (_cd->AMX_BF16()) { @@ -1897,7 +1921,7 @@ void bestla_fusion_attn_forward(const attn_fwd_args_t -void bestla_fusion_attn_forward( +inline void bestla_fusion_attn_forward( const attn_fwd_args_t& params) { GetCPUDevice(); const auto pth = ne_threading::get(); @@ -1940,7 +1964,8 @@ void bestla_fusion_attn_forward( } template <> -void bestla_fusion_attn_forward(const attn_fwd_args_t& params) { +inline void bestla_fusion_attn_forward( + const attn_fwd_args_t& params) { GetCPUDevice(); const auto pth = ne_threading::get(); if (_cd->AVX512F() && (params.attn_flags & NE_ATTN_FLAG_PREFER_FP32) != 0) { @@ -1981,7 +2006,7 @@ void bestla_fusion_attn_forward(const attn_fwd_args_t< } template -void bestla_fusion_attn_forward_ref(const attn_fwd_args_t& p) { +inline void bestla_fusion_attn_forward_ref(const attn_fwd_args_t& p) { const bool is_causal = (p.attn_flags & NE_ATTN_FLAG_IS_CAUSAL) != 0; const bool is_alibi = (p.attn_flags & NE_ATTN_FLAG_IS_ALIBI8) != 0; const bool prefer_fp32 = (p.attn_flags & NE_ATTN_FLAG_PREFER_FP32) != 0; @@ -2127,3 +2152,4 @@ void bestla_fusion_attn_forward_ref(const attn_fwd_args_t& } // namespace mha } // namespace custom } // namespace ne_bestla +#endif // NE_CORE_GRAPH_MHA_DENSE_WRAPPER_H diff --git a/neural_speed/core/layers/ne_test_layers_utils.hpp b/neural_speed/core/layers/ne_test_layers_utils.hpp index fd9607968..9b17c42ac 100644 --- a/neural_speed/core/layers/ne_test_layers_utils.hpp +++ b/neural_speed/core/layers/ne_test_layers_utils.hpp @@ -11,13 +11,15 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +#ifndef NE_CORE_GRAPH_NE_TEST_LAYERS_UTILS_H +#define NE_CORE_GRAPH_NE_TEST_LAYERS_UTILS_H +#include #include +#include #include #include -#include #include -#include #include "bestla/bestla_utils.h" @@ -91,3 +93,4 @@ bool compare_data(const T* buf1, const T* buf2, size_t size, float eps = 1e-6) { } return true; } +#endif // NE_CORE_GRAPH_NE_TEST_LAYERS_UTILS_H