diff --git a/lite/kernels/xpu/__xpu__mmdnn_compute.cc b/lite/kernels/xpu/__xpu__mmdnn_compute.cc index 10e8e3c318d..adb338ce80c 100644 --- a/lite/kernels/xpu/__xpu__mmdnn_compute.cc +++ b/lite/kernels/xpu/__xpu__mmdnn_compute.cc @@ -1112,7 +1112,7 @@ void XPUMmdnnBidEmbGrnnAttCompute2::Run() { int table_m = param.emb_tbl->dims()[0]; int embed_dim = param.emb_tbl->dims()[1]; - int r = xdnn::embedding( + int r = xdnn::paddle_embedding( ctx.GetRawContext(), param.emb_tbl->data(), param.id0->data(), @@ -1122,6 +1122,7 @@ void XPUMmdnnBidEmbGrnnAttCompute2::Run() { num, 128000); CHECK_EQ(r, 0); + } class XPUMmdnnBidEmbAttCompute diff --git a/lite/kernels/xpu/lookup_table_compute.cc b/lite/kernels/xpu/lookup_table_compute.cc index efbdc6aa628..a3548ea6972 100644 --- a/lite/kernels/xpu/lookup_table_compute.cc +++ b/lite/kernels/xpu/lookup_table_compute.cc @@ -30,7 +30,7 @@ void LookupTableCompute::Run() { int xm = param.W->dims()[0]; int n = param.W->dims()[1]; - int r = xdnn::embedding( + int r = xdnn::paddle_embedding( ctx.GetRawContext(), /* context */ param.W->template data(), param.Ids->template data(),