Skip to content

Commit

Permalink
[SYCL] Update SYCL upscale operation (ggerganov#7321)
Browse files Browse the repository at this point in the history
* Update SYCL upscale operation

* Formatting

* Remove messages
  • Loading branch information
AidanBeltonS authored May 20, 2024
1 parent 26cd423 commit 6bf9b66
Showing 1 changed file with 35 additions and 30 deletions.
65 changes: 35 additions & 30 deletions ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3847,21 +3847,27 @@ static void concat_f32(const float *x,const float *y, float *dst, const int ne
}
}

static void upscale_f32(const float *x, float *dst, const int ne00, const int nb02, const int scale_factor,
const sycl::nd_item<3> &item_ct1) {
int ne0 = ne00 * scale_factor;
int nidx = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
if (nidx >= ne0) {
static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
const int nb02, const int nb03, const int ne10, const int ne11,
const int ne12, const int ne13, const float sf0, const float sf1,
const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
int index = item_ct1.get_local_id(0) +
item_ct1.get_group(0) * item_ct1.get_local_range(0);
if (index >= ne10 * ne11 * ne12 * ne13) {
return;
}
// operation
int i00 = nidx / scale_factor;
int i01 = item_ct1.get_group(1) / scale_factor;
int offset_src = i00 + i01 * ne00 + item_ct1.get_group(0) * nb02;
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
dst[offset_dst] = x[offset_src];
int i10 = index % ne10;
int i11 = (index / ne10) % ne11;
int i12 = (index / (ne10 * ne11)) % ne12;
int i13 = (index / (ne10 * ne11 * ne12)) % ne13;

int i00 = i10 / sf0;
int i01 = i11 / sf1;
int i02 = i12 / sf2;
int i03 = i13 / sf3;

dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
}

static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
Expand Down Expand Up @@ -10085,18 +10091,17 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
});
}

static void upscale_f32_sycl(const float *x, float *dst, const int ne00,
const int ne01, const int ne02,
const int scale_factor, dpct::queue_ptr stream) {
int ne0 = (ne00 * scale_factor);
int num_blocks = (ne0 + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
sycl::range<3> gridDim(ne02, (ne01 * scale_factor), num_blocks);
static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
const int nb02, const int nb03, const int ne10, const int ne11,
const int ne12, const int ne13, const float sf0, const float sf1,
const float sf2, const float sf3, dpct::queue_ptr stream) {
int dst_size = ne10 * ne11 * ne12 * ne13;
int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<3>(gridDim *
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
upscale_f32(x, dst, ne00, ne00 * ne01, scale_factor, item_ct1);
sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
});
}

Expand Down Expand Up @@ -13985,15 +13990,15 @@ inline void ggml_sycl_op_upscale(const ggml_tensor *src0,

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors

#pragma message("TODO: generalize upscale operator")
#pragma message(" https://github.com/ggerganov/ggml/pull/814")
GGML_ASSERT(false && "TODO: generalize upscale operator");

const int scale_factor = dst->op_params[0];
const float sf0 = (float)dst->ne[0]/src0->ne[0];
const float sf1 = (float)dst->ne[1]/src0->ne[1];
const float sf2 = (float)dst->ne[2]/src0->ne[2];
const float sf3 = (float)dst->ne[3]/src0->ne[3];

upscale_f32_sycl(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
main_stream);

(void) src1;
(void) dst;
Expand Down

0 comments on commit 6bf9b66

Please sign in to comment.