Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sycl-exp : Enabled more data types for oneMKL's gemm_batch API #8182

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 2 additions & 27 deletions ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2490,6 +2490,7 @@ namespace dpct
b, ldb, beta, c, ldc, batch_size);
break;
}
#endif
case detail::get_type_combination_id(
library_data_t::real_int8, library_data_t::real_int8,
library_data_t::real_int32, library_data_t::real_int32):
Expand Down Expand Up @@ -2522,7 +2523,6 @@ namespace dpct
batch_size);
break;
}
#endif
case detail::get_type_combination_id(
library_data_t::real_half, library_data_t::real_half,
library_data_t::real_half, library_data_t::real_float):
Expand Down Expand Up @@ -2659,6 +2659,7 @@ namespace dpct
stride_c, batch_size);
break;
}
#endif
case detail::get_type_combination_id(
library_data_t::real_int8, library_data_t::real_int8,
library_data_t::real_int32, library_data_t::real_int32):
Expand Down Expand Up @@ -2687,7 +2688,6 @@ namespace dpct
beta, c, ldc, stride_c, batch_size);
break;
}
#endif
case detail::get_type_combination_id(
library_data_t::real_half, library_data_t::real_half,
library_data_t::real_half, library_data_t::real_float):
Expand Down Expand Up @@ -15026,9 +15026,6 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
SYCL_CHECK(ggml_sycl_set_device(g_main_device));
dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0];

bool no_mixed_dtypes = main_stream->get_backend() == sycl::backend::ext_oneapi_cuda ||
main_stream->get_backend() == sycl::backend::ext_oneapi_hip;

SYCL_CHECK(
CHECK_TRY_ERROR(g_sycl_handles[g_main_device] = main_stream));

Expand Down Expand Up @@ -15059,10 +15056,6 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,

dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
if (no_mixed_dtypes) {
cu_compute_type = dpct::library_data_t::real_half;
cu_data_type = dpct::library_data_t::real_half;
}

// dst strides
size_t nbd2 = dst->nb[2];
Expand All @@ -15076,21 +15069,8 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,

const void * alpha = &alpha_f32;
const void * beta = &beta_f32;
if (no_mixed_dtypes) {
alpha = &alpha_f16;
beta = &beta_f16;
}

// TODO: Renable (dst->op_params[0] =! GGML_PREC_DEFAULT) pathway
// when oneMKL open source supports half, half, float, float: datatypes

dst_t = (char *) dst_ddf;
if (no_mixed_dtypes) {
dst_t = (char *) dst_f16.alloc(ne_dst);

nbd2 /= sizeof(float) / sizeof(sycl::half);
nbd3 /= sizeof(float) / sizeof(sycl::half);
}

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
Expand Down Expand Up @@ -15152,11 +15132,6 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
(void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
cu_compute_type)));
}

if (no_mixed_dtypes) {
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
to_fp32_sycl(dst_f16.get(), dst_ddf, ne_dst, main_stream);
}
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
Expand Down
Loading