diff --git a/paddle/phi/kernels/gpu/erfinv_grad_kernel.cu b/paddle/phi/kernels/gpu/erfinv_grad_kernel.cu
index 034788502b043..055caf66b1e14 100644
--- a/paddle/phi/kernels/gpu/erfinv_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/erfinv_grad_kernel.cu
@@ -22,5 +22,11 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/erfinv_grad_kernel_impl.h"
 
-PD_REGISTER_KERNEL(
-    erfinv_grad, GPU, ALL_LAYOUT, phi::ErfinvGradKernel, float, double) {}
+PD_REGISTER_KERNEL(erfinv_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ErfinvGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/erfinv_kernel.cu b/paddle/phi/kernels/gpu/erfinv_kernel.cu
index 087eb46c33fdc..87a8712f7577a 100644
--- a/paddle/phi/kernels/gpu/erfinv_kernel.cu
+++ b/paddle/phi/kernels/gpu/erfinv_kernel.cu
@@ -23,7 +23,21 @@
 template <typename T>
 struct ErfinvFunctor {
   HOSTDEVICE inline T operator()(const T x) const { return erfinv(x); }
 };
+template <>
+struct ErfinvFunctor<float16> {
+  HOSTDEVICE inline float16 operator()(const float16 x) const {
+    auto x_ = static_cast<float>(x);
+    return static_cast<float16>(erfinv(x_));
+  }
+};
+template <>
+struct ErfinvFunctor<bfloat16> {
+  HOSTDEVICE inline bfloat16 operator()(const bfloat16 x) const {
+    auto x_ = static_cast<float>(x);
+    return static_cast<bfloat16>(erfinv(x_));
+  }
+};
 template <typename T, typename Context>
 void ErfinvKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
   ctx.template Alloc<T>(out);
@@ -34,4 +48,11 @@ void ErfinvKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(erfinv, GPU, ALL_LAYOUT, phi::ErfinvKernel, float, double) {}
+PD_REGISTER_KERNEL(erfinv,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ErfinvKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/impl/erfinv_grad_kernel_impl.h b/paddle/phi/kernels/impl/erfinv_grad_kernel_impl.h
index 2d41d04f43c49..8f2315fc9d8ae 100644
--- a/paddle/phi/kernels/impl/erfinv_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/erfinv_grad_kernel_impl.h
@@ -29,7 +29,7 @@ void ErfinvGradKernel(const Context& ctx,
   auto eigen_dout = EigenVector<T>::Flatten(out_grad);
   auto eigen_dx = EigenVector<T>::Flatten(*x_grad);
   auto& place = *ctx.eigen_device();
-  constexpr T half_sqrt_pi = static_cast<T>(1 / M_2_SQRTPI);
+  T half_sqrt_pi = static_cast<T>(1 / M_2_SQRTPI);
   eigen_dx.device(place) =
       half_sqrt_pi * eigen_dout * eigen_out.square().exp();
 }
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 2dcfcb5b08ff6..dd5337dc4a9bc 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -4760,7 +4760,7 @@ def erfinv(x, name=None):
             erfinv(erf(x)) = x.
 
     Args:
-        x (Tensor): An N-D Tensor, the data type is float32, float64.
+        x (Tensor): An N-D Tensor, the data type is float16, bfloat16, float32, float64.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
@@ -4779,7 +4779,9 @@ def erfinv(x, name=None):
    if in_dynamic_mode():
        return _C_ops.erfinv(x)
    else:
-        check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'erfinv')
+        check_variable_and_dtype(
+            x, 'x', ['float32', 'float64', 'float16', 'uint16'], 'erfinv'
+        )
         helper = LayerHelper('erfinv', **locals())
         out = helper.create_variable_for_type_inference(dtype=x.dtype)
         helper.append_op(type='erfinv', inputs={'X': x}, outputs={'Out': out})
diff --git a/test/legacy_test/test_erfinv_op.py b/test/legacy_test/test_erfinv_op.py
index 95f1f3621dc5e..f6a51285de6e2 100644
--- a/test/legacy_test/test_erfinv_op.py
+++ b/test/legacy_test/test_erfinv_op.py
@@ -15,7 +15,11 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import (
+    OpTest,
+    convert_float_to_uint16,
+    convert_uint16_to_float,
+)
 from scipy.special import erfinv
 
 import paddle
@@ -25,7 +29,7 @@
 np.random.seed(0)
 
 
-class TestErfinv(OpTest):
+class TestErfinvOp(OpTest):
     def setUp(self):
         self.op_type = "erfinv"
         self.python_api = paddle.erfinv
@@ -55,12 +59,12 @@ def test_check_grad(self):
         )
 
 
-class TestErfinvFP32(TestErfinv):
+class TestErfinvFP64Op(TestErfinvOp):
     def init_dtype(self):
-        self.dtype = np.float32
+        self.dtype = np.float64
 
 
-class TestErfinvAPI(unittest.TestCase):
+class TestErfinvAPIOp(unittest.TestCase):
     def init_dtype(self):
         self.dtype = 'float32'
 
@@ -110,5 +114,49 @@ def run(place):
             run(place)
 
 
+class TestErfinvFP16Op(TestErfinvOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestErfinvBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "erfinv"
+        self.public_python_api = paddle.erfinv
+        self.python_api = paddle.erfinv
+        self.dtype = np.uint16
+        self.shape = [11, 17]
+        self.datatype = np.float32
+        self.input_data = np.random.uniform(-1, 1, size=self.shape).astype(
+            self.datatype
+        )
+        self.inputs = {'X': convert_float_to_uint16(self.input_data)}
+        self.inputs_data = convert_uint16_to_float(self.inputs['X'])
+        out_ref = erfinv(self.input_data)
+        self.grad_out = np.ones(self.shape, self.datatype)
+        self.gradient = (
+            np.sqrt(np.pi) / 2 * np.exp(np.square(out_ref)) * self.grad_out
+        )
+
+        self.outputs = {'Out': convert_float_to_uint16(out_ref)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place,
+            ['X'],
+            'Out',
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
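
Note (not part of the patch): both the backward expression in ErfinvGradKernel, half_sqrt_pi * eigen_dout * eigen_out.square().exp(), and the reference gradient built in TestErfinvBF16Op rely on the identity d/dy erfinv(y) = sqrt(pi)/2 * exp(erfinv(y)^2). A minimal NumPy/SciPy sketch to sanity-check that identity against a central finite difference; the sample range and tolerance are arbitrary choices, not values taken from the test suite:

# Illustrative check only: analytic erfinv gradient vs. central finite difference.
import numpy as np
from scipy.special import erfinv

y = np.random.uniform(-0.9, 0.9, size=(11, 17))  # stay away from the poles at +/-1
eps = 1e-6

analytic = np.sqrt(np.pi) / 2 * np.exp(np.square(erfinv(y)))   # d/dy erfinv(y)
numeric = (erfinv(y + eps) - erfinv(y - eps)) / (2 * eps)      # finite-difference estimate

np.testing.assert_allclose(analytic, numeric, rtol=1e-5)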
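
For context, a minimal usage sketch of what the new registrations enable from the Python side: erfinv forward and backward on a float16 tensor on GPU. This assumes a CUDA build of Paddle; the input values are arbitrary and the snippet is illustrative, not part of the patch:

# Illustrative only (assumes a CUDA build of Paddle): float16 erfinv forward/backward
# now dispatch to the GPU kernels registered above.
import paddle

if paddle.is_compiled_with_cuda():
    paddle.device.set_device("gpu")
    x = paddle.to_tensor([-0.5, 0.0, 0.5], dtype="float16", stop_gradient=False)
    y = paddle.erfinv(x)
    y.sum().backward()
    print(y.numpy(), x.grad.numpy())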