Add FP16 & BF16 for erfinv (#55287)
enkilee authored Aug 2, 2023
1 parent 19da5c0 commit 6d7efd0
Showing 5 changed files with 88 additions and 11 deletions.
10 changes: 8 additions & 2 deletions paddle/phi/kernels/gpu/erfinv_grad_kernel.cu
@@ -22,5 +22,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/erfinv_grad_kernel_impl.h"

-PD_REGISTER_KERNEL(
-    erfinv_grad, GPU, ALL_LAYOUT, phi::ErfinvGradKernel, float, double) {}
+PD_REGISTER_KERNEL(erfinv_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ErfinvGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
23 changes: 22 additions & 1 deletion paddle/phi/kernels/gpu/erfinv_kernel.cu
@@ -23,7 +23,21 @@ template <typename T>
struct ErfinvFunctor {
  HOSTDEVICE inline T operator()(const T x) const { return erfinv(x); }
};
+template <>
+struct ErfinvFunctor<float16> {
+  HOSTDEVICE inline float16 operator()(const float16 x) const {
+    auto x_ = static_cast<float>(x);
+    return static_cast<float16>(erfinv(x_));
+  }
+};
+
+template <>
+struct ErfinvFunctor<bfloat16> {
+  HOSTDEVICE inline bfloat16 operator()(const bfloat16 x) const {
+    auto x_ = static_cast<float>(x);
+    return static_cast<bfloat16>(erfinv(x_));
+  }
+};
template <typename T, typename Context>
void ErfinvKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
  ctx.template Alloc<T>(out);
@@ -34,4 +48,11 @@ void ErfinvKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {

} // namespace phi

-PD_REGISTER_KERNEL(erfinv, GPU, ALL_LAYOUT, phi::ErfinvKernel, float, double) {}
+PD_REGISTER_KERNEL(erfinv,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ErfinvKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
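The float16 and bfloat16 specializations above up-cast the input to float, evaluate erfinv in single precision, and cast the result back to the narrow type. A minimal NumPy/SciPy sketch of the same compute-in-float32 pattern (the function name here is illustrative, not part of the Paddle code):

import numpy as np
from scipy.special import erfinv

def erfinv_low_precision(x):
    # Mirror the GPU functor: promote to float32, evaluate erfinv,
    # then cast the result back to the original low-precision dtype.
    x_f32 = x.astype(np.float32)
    return erfinv(x_f32).astype(x.dtype)

x = np.random.uniform(-0.9, 0.9, size=8).astype(np.float16)
print(erfinv_low_precision(x))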
2 changes: 1 addition & 1 deletion paddle/phi/kernels/impl/erfinv_grad_kernel_impl.h
@@ -29,7 +29,7 @@ void ErfinvGradKernel(const Context& ctx,
  auto eigen_dout = EigenVector<T>::Flatten(out_grad);
  auto eigen_dx = EigenVector<T>::Flatten(*x_grad);
  auto& place = *ctx.eigen_device();
-  constexpr T half_sqrt_pi = static_cast<T>(1 / M_2_SQRTPI);
+  T half_sqrt_pi = static_cast<T>(1 / M_2_SQRTPI);
  eigen_dx.device(place) = half_sqrt_pi * eigen_dout * eigen_out.square().exp();
}
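The gradient kernel applies the analytic derivative of the inverse error function: M_2_SQRTPI is 2/sqrt(pi), so half_sqrt_pi equals sqrt(pi)/2 and the kernel computes dx = sqrt(pi)/2 * exp(erfinv(x)^2) * dout. A quick illustrative NumPy check of that identity against a central finite difference (not part of the commit):

import numpy as np
from scipy.special import erfinv

x = np.linspace(-0.9, 0.9, 7)
# d/dx erfinv(x) = sqrt(pi)/2 * exp(erfinv(x)^2)
analytic = np.sqrt(np.pi) / 2 * np.exp(erfinv(x) ** 2)

# Central finite difference of scipy's erfinv for comparison.
eps = 1e-6
numeric = (erfinv(x + eps) - erfinv(x - eps)) / (2 * eps)

print(np.max(np.abs(analytic - numeric)))  # should be tiny away from +/-1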

6 changes: 4 additions & 2 deletions python/paddle/tensor/math.py
@@ -4760,7 +4760,7 @@ def erfinv(x, name=None):
erfinv(erf(x)) = x.
Args:
-x (Tensor): An N-D Tensor, the data type is float32, float64.
+x (Tensor): An N-D Tensor, the data type is float16, bfloat16, float32, float64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
@@ -4779,7 +4779,9 @@ def erfinv(x, name=None):
    if in_dynamic_mode():
        return _C_ops.erfinv(x)
    else:
-        check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'erfinv')
+        check_variable_and_dtype(
+            x, 'x', ['float32', 'float64', 'float16', 'uint16'], 'erfinv'
+        )
        helper = LayerHelper('erfinv', **locals())
        out = helper.create_variable_for_type_inference(dtype=x.dtype)
        helper.append_op(type='erfinv', inputs={'X': x}, outputs={'Out': out})
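With these registrations in place, float16 and bfloat16 tensors can be passed to paddle.erfinv directly. A minimal usage sketch, assuming a CUDA build of Paddle since the half-precision kernels are registered for the GPU backend:

import paddle

paddle.set_device('gpu')  # assumption: a CUDA-enabled build is available
x = paddle.to_tensor([0.0, 0.5, -0.5], dtype='float16')
y = paddle.erfinv(x)
print(y)  # approximately [0.0, 0.4769, -0.4769] in float16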
58 changes: 53 additions & 5 deletions test/legacy_test/test_erfinv_op.py
@@ -15,7 +15,11 @@
import unittest

import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import (
+    OpTest,
+    convert_float_to_uint16,
+    convert_uint16_to_float,
+)
from scipy.special import erfinv

import paddle
@@ -25,7 +29,7 @@
np.random.seed(0)


-class TestErfinv(OpTest):
+class TestErfinvOp(OpTest):
    def setUp(self):
        self.op_type = "erfinv"
        self.python_api = paddle.erfinv
@@ -55,12 +59,12 @@ def test_check_grad(self):
)


-class TestErfinvFP32(TestErfinv):
+class TestErfinvFP64Op(TestErfinvOp):
    def init_dtype(self):
-        self.dtype = np.float32
+        self.dtype = np.float64


-class TestErfinvAPI(unittest.TestCase):
+class TestErfinvAPIOp(unittest.TestCase):
    def init_dtype(self):
        self.dtype = 'float32'

@@ -110,5 +114,49 @@ def run(place):
run(place)


+class TestErfinvFP16Op(TestErfinvOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestErfinvBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "erfinv"
+        self.public_python_api = paddle.erfinv
+        self.python_api = paddle.erfinv
+        self.dtype = np.uint16
+        self.shape = [11, 17]
+        self.datatype = np.float32
+        self.input_data = np.random.uniform(-1, 1, size=self.shape).astype(
+            self.datatype
+        )
+        self.inputs = {'X': convert_float_to_uint16(self.input_data)}
+        self.inputs_data = convert_uint16_to_float(self.inputs['X'])
+        out_ref = erfinv(self.input_data)
+        self.grad_out = np.ones(self.shape, self.datatype)
+        self.gradient = (
+            np.sqrt(np.pi) / 2 * np.exp(np.square(out_ref)) * self.grad_out
+        )
+
+        self.outputs = {'Out': convert_float_to_uint16(out_ref)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place,
+            ['X'],
+            'Out',
+        )
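The BF16 test stores tensors as np.uint16 arrays holding raw bfloat16 bits via convert_float_to_uint16 / convert_uint16_to_float. As a rough, truncation-based sketch of that representation (the real helpers in eager_op_test may round rather than truncate):

import numpy as np

def float32_to_bf16_bits(x):
    # Keep the upper 16 bits of each float32 value (bfloat16 is the
    # high half of an IEEE float32); simple truncation, no rounding.
    return (x.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float32(bits):
    # Shift the bfloat16 bit pattern back into the high half of a float32.
    return (bits.astype(np.uint32) << 16).view(np.float32)

x = np.array([0.5, -0.25, 1.0], dtype=np.float32)
print(bf16_bits_to_float32(float32_to_bf16_bits(x)))  # exactly representable values round-trip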


if __name__ == "__main__":
    unittest.main()
