
Commit

Add FP16 & BF16 for nanmedian (#56056)
enkilee authored Aug 9, 2023
1 parent 08e46d6 commit 4ae9945
Showing 4 changed files with 52 additions and 4 deletions.
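
For context, this commit registers the GPU forward and backward nanmedian kernels for float16 and bfloat16, whitelists the new dtypes in the Python-side check, and adds OpTest coverage for both. A minimal usage sketch of what the change enables (a hedged example assuming a CUDA build of Paddle that includes this commit; values are illustrative):

import paddle

paddle.set_device('gpu')
x = paddle.to_tensor(
    [[1.0, float('nan'), 3.0],
     [4.0, 5.0, float('nan')]],
    dtype='float16',
)
print(paddle.nanmedian(x))                     # NaNs are ignored: median of {1, 3, 4, 5} -> 3.5
print(paddle.nanmedian(x.astype('bfloat16')))  # bfloat16 inputs are now accepted as well
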
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu
@@ -123,4 +123,5 @@ PD_REGISTER_KERNEL(nanmedian_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
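
Registering the grad kernel for the new dtypes is what allows autograd to run through nanmedian on half-precision inputs. A hedged sketch of the backward behaviour this enables (assuming a CUDA build containing this commit):

import paddle

paddle.set_device('gpu')
x = paddle.to_tensor([1.0, float('nan'), 3.0, 5.0],
                     dtype='float16', stop_gradient=False)
y = paddle.nanmedian(x)  # NaN is skipped; the median of {1, 3, 5} is 3
y.backward()
print(x.grad)            # the gradient is routed to the element(s) that produced the median
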
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/nanmedian_kernel.cu
@@ -287,6 +287,7 @@ PD_REGISTER_KERNEL(nanmedian,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
 }
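
Note that the registration body above pins the kernel's second output, the median index used internally by the backward pass, to int64 regardless of the input dtype; the median value itself keeps the input dtype. A small sketch (assuming a CUDA build with this commit):

import paddle

x = paddle.rand([4, 5]).astype('float16')
out = paddle.nanmedian(x, axis=1)
print(out.shape)  # [4]
print(out.dtype)  # paddle.float16 -- the median keeps the input dtype
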
4 changes: 2 additions & 2 deletions python/paddle/tensor/stat.py
@@ -265,7 +265,7 @@ def nanmedian(x, axis=None, keepdim=False, name=None):
     the average value of both elements in the middle is calculated as the median.
 
     Args:
-        x (Tensor): The input Tensor, it's data type can be int32, int64, float16, float32, float64.
+        x (Tensor): The input Tensor, it's data type can be int32, int64, float16, bfloat16, float32, float64.
         axis (None|int|list|tuple, optional):
             The axis along which to perform median calculations ``axis`` should be int or list of int.
             ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
@@ -319,7 +319,7 @@ def nanmedian(x, axis=None, keepdim=False, name=None):
         check_variable_and_dtype(
             x,
             'X',
-            ['int32', 'int64', 'float16', 'float32', 'float64'],
+            ['int32', 'int64', 'float16', 'float32', 'float64', 'uint16'],
             'nanmedian',
         )
 
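
The 'uint16' entry is how the static-graph dtype check admits bfloat16: Paddle's legacy type system carries bfloat16 variables with the uint16 VarType. A hedged sketch of the graph-building path this unblocks (the dtype string and shapes are illustrative, not taken from the commit):

import paddle

paddle.enable_static()
# In static graphs, bfloat16 data is conventionally declared with the 'uint16' dtype string.
x = paddle.static.data(name='x', shape=[8, 8], dtype='uint16')
y = paddle.nanmedian(x, axis=0)  # passes check_variable_and_dtype after this change
paddle.disable_static()
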
46 changes: 46 additions & 0 deletions test/legacy_test/test_nanmedian.py
@@ -15,6 +15,7 @@
 import unittest
 
 import numpy as np
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle.fluid import core
@@ -243,5 +244,50 @@ def test_check_grad_0d(self):
         np.testing.assert_allclose(x.grad, np.array(0.0))
 
 
+class TestNanmedianFP16Op(OpTest):
+    def setUp(self):
+        self.op_type = "nanmedian"
+        self.python_api = paddle.nanmedian
+        self.public_python_api = paddle.nanmedian
+        self.dtype = np.float16
+        self.python_out_sig = ["Out"]
+        X = np.random.random((100, 100)).astype('float16')
+        Out = np.nanmedian(X)
+        self.inputs = {'X': X}
+        self.outputs = {'Out': Out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not complied with CUDA and not support the bfloat16",
+)
+class TestNanmedianBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "nanmedian"
+        self.python_api = paddle.nanmedian
+        self.public_python_api = paddle.nanmedian
+        self.dtype = np.uint16
+        self.python_out_sig = ["Out"]
+        X = np.random.random((100, 100)).astype('float32')
+        Out = np.nanmedian(X)
+        self.inputs = {'X': convert_float_to_uint16(X)}
+        self.outputs = {'Out': convert_float_to_uint16(Out)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X'], 'Out')
+
+
 if __name__ == "__main__":
     unittest.main()
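
Because NumPy has no native bfloat16 dtype, the BF16 test feeds convert_float_to_uint16 outputs to OpTest: bfloat16 values are carried as the upper 16 bits of their float32 representation, stored in uint16 arrays. A conceptual sketch of that conversion (not the actual helper from eager_op_test):

import numpy as np

def float32_to_bf16_bits(a):
    """Convert float32 values to bfloat16 bit patterns (round-to-nearest-even)."""
    bits = np.asarray(a, dtype=np.float32).view(np.uint32)
    rounding_bias = ((bits >> 16) & 1) + np.uint32(0x7FFF)
    return ((bits + rounding_bias) >> 16).astype(np.uint16)

x = np.random.random((4, 4)).astype('float32')
print(float32_to_bf16_bits(x).dtype)  # uint16, the dtype the BF16 OpTest expects

The new cases run with the rest of the legacy suite, e.g. python -m pytest test/legacy_test/test_nanmedian.py on a machine that satisfies the skip condition.
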
