From d8d29f15bd6c6e8220b187fd65bea7c3eda1b8a0 Mon Sep 17 00:00:00 2001 From: loneranger <836253168@qq.com> Date: Tue, 14 Mar 2023 22:05:56 +0800 Subject: [PATCH 01/14] add fp16 and bf16 support for poisson --- paddle/phi/kernels/gpu/poisson_grad_kernel.cu | 11 ++- paddle/phi/kernels/gpu/poisson_kernel.cu | 11 ++- paddle/phi/kernels/poisson_kernel.h | 1 + .../fluid/tests/unittests/test_poisson_op.py | 67 ++++++++++++++++++- 4 files changed, 85 insertions(+), 5 deletions(-) diff --git a/paddle/phi/kernels/gpu/poisson_grad_kernel.cu b/paddle/phi/kernels/gpu/poisson_grad_kernel.cu index 8c16bc51fffe5..1990ba1fb7f48 100644 --- a/paddle/phi/kernels/gpu/poisson_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/poisson_grad_kernel.cu @@ -12,8 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/poisson_grad_kernel_impl.h" -PD_REGISTER_KERNEL( - poisson_grad, GPU, ALL_LAYOUT, phi::PoissonGradKernel, float, double) {} +PD_REGISTER_KERNEL(poisson_grad, + GPU, + ALL_LAYOUT, + phi::PoissonGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/poisson_kernel.cu b/paddle/phi/kernels/gpu/poisson_kernel.cu index 302a9fe5ce581..3ba019b83a3f7 100644 --- a/paddle/phi/kernels/gpu/poisson_kernel.cu +++ b/paddle/phi/kernels/gpu/poisson_kernel.cu @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/poisson_kernel.h" @@ -64,5 +65,11 @@ void PoissonKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) { } // namespace phi -PD_REGISTER_KERNEL( - poisson, GPU, ALL_LAYOUT, phi::PoissonKernel, float, double) {} +PD_REGISTER_KERNEL(poisson, + GPU, + ALL_LAYOUT, + phi::PoissonKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/poisson_kernel.h b/paddle/phi/kernels/poisson_kernel.h index b2b2ea97f014e..012b50378ccfc 100644 --- a/paddle/phi/kernels/poisson_kernel.h +++ b/paddle/phi/kernels/poisson_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" namespace phi { diff --git a/python/paddle/fluid/tests/unittests/test_poisson_op.py b/python/paddle/fluid/tests/unittests/test_poisson_op.py index ee66d578014c7..c440645a4a295 100644 --- a/python/paddle/fluid/tests/unittests/test_poisson_op.py +++ b/python/paddle/fluid/tests/unittests/test_poisson_op.py @@ -16,9 +16,10 @@ import unittest import numpy as np -from eager_op_test import OpTest +from eager_op_test import OpTest, convert_float_to_uint16 import paddle +import paddle.fluid.core as core paddle.enable_static() paddle.seed(100) @@ -368,5 +369,69 @@ def test_fixed_random_number(self): paddle.enable_static() +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_float16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the float16", +) +class TestPoissonFP16OP(OpTest): + def setUp(self): + self.op_type = "poisson" + self.python_api = paddle.tensor.poisson + self.config() + self.__class__.op_type = self.op_type + x = np.full([2048, 1024], self.lam, dtype=self.dtype) + out = np.ones([2048, 1024]) + self.attrs = {} + self.inputs = {'X': x.astype(self.dtype)} + self.outputs = {'Out': out} + + def config(self): + self.lam = 10 + self.a = 5 + self.b = 15 + self.dtype = np.float16 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=1e-3) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place(place, ['X'], 'Out', max_relative_error=1e-2) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) +class TestPoissonBF16(OpTest): + def setUp(self): + self.op_type = "poisson" + self.python_api = paddle.tensor.poisson + self.config() + self.__class__.op_type = self.op_type + x = np.full([2048, 1024], self.lam, dtype=self.dtype) + out = np.ones([2048, 1024]) + self.attrs = {} + self.inputs = {'X': convert_float_to_uint16(x)} + self.outputs = {'Out': convert_float_to_uint16(out)} + + def config(self): + self.lam = 10 + self.a = 5 + self.b = 15 + self.dtype = np.uint16 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=1e-3) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place(place, ['X'], 'Out', max_relative_error=1e-2) + + if __name__ == "__main__": unittest.main() From 0d697559cb89bac1f9f8f8f750eb0b4aa2d28f4e Mon Sep 17 00:00:00 2001 From: loneranger <836253168@qq.com> Date: Tue, 14 Mar 2023 22:31:43 +0800 Subject: [PATCH 02/14] add fp16 and bf16 support for searchsorted --- paddle/phi/kernels/gpu/searchsorted_kernel.cu | 5 +- paddle/phi/kernels/searchsorted_kernel.h | 1 + .../tests/unittests/test_searchsorted_op.py | 72 ++++++++++++++++++- 3 files changed, 76 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/gpu/searchsorted_kernel.cu b/paddle/phi/kernels/gpu/searchsorted_kernel.cu index 4a2ce2241c22d..58fd518598e49 100644 --- a/paddle/phi/kernels/gpu/searchsorted_kernel.cu +++ b/paddle/phi/kernels/gpu/searchsorted_kernel.cu @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/searchsorted_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/searchsorted_kernel_impl.h" @@ -25,4 +26,6 @@ PD_REGISTER_KERNEL(searchsorted, float, double, int, - int64_t) {} + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/searchsorted_kernel.h b/paddle/phi/kernels/searchsorted_kernel.h index e425c7fd79555..953caf9466d7e 100644 --- a/paddle/phi/kernels/searchsorted_kernel.h +++ b/paddle/phi/kernels/searchsorted_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" namespace phi { diff --git a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py index c83e8d534463f..79801d5576778 100644 --- a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py +++ b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from op_test import OpTest +from op_test import OpTest, convert_float_to_uint16 import paddle import paddle.fluid.core as core @@ -215,5 +215,75 @@ def test_sortedsequence_values_type_error(): self.assertRaises(TypeError, test_sortedsequence_values_type_error) +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_float16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the float16", +) +class TestSearchSortedFP16OP(OpTest): + def setUp(self): + self.python_api = paddle.searchsorted + self.op_type = "searchsorted" + self.__class__.op_type = self.op_type + self.dtype = np.float16 + self.init_test_case() + + sorted_sequence = self.sorted_sequence + values = self.values + out = np.searchsorted(self.sorted_sequence, self.values, side=self.side) + + self.inputs = { + 'SortedSequence': sorted_sequence.astype(self.dtype), + 'Values': values.astype(self.dtype), + } + self.attrs = {"out_int32": False, "right": False} + self.attrs["right"] = True if self.side == 'right' else False + self.outputs = {'Out': out} + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=1e-3, check_eager=True) + + def init_test_case(self): + self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype(np.float32) + self.values = np.array([[3, 6, 9], [3, 6, 9]]).astype(np.float32) + self.side = "left" + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) +class TestSearchSortedBF16(OpTest): + def setUp(self): + self.python_api = paddle.searchsorted + self.op_type = "searchsorted" + self.__class__.op_type = self.op_type + self.dtype = np.uint16 + self.init_test_case() + + sorted_sequence = self.sorted_sequence + values = self.values + out = np.searchsorted(self.sorted_sequence, self.values, side=self.side) + + self.inputs = { + 'SortedSequence': convert_float_to_uint16(sorted_sequence), + 'Values': convert_float_to_uint16(values), + } + self.attrs = {"out_int32": False, "right": False} + self.attrs["right"] = True if self.side == 'right' else False + self.outputs = {'Out': convert_float_to_uint16(out)} + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=1e-3, check_eager=True) + + def init_test_case(self): + self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype(np.float32) + self.values = np.array([[3, 6, 9], [3, 6, 9]]).astype(np.float32) + self.side = "left" + + if __name__ == '__main__': unittest.main() From 2718d8d1c58de9e8c57d55ba9a9333ffdb516cfa Mon Sep 17 00:00:00 2001 From: longranger2 <836253168@qq.com> Date: Sun, 19 Mar 2023 22:56:06 +0800 Subject: [PATCH 03/14] fix bug --- paddle/phi/kernels/gpu/poisson_grad_kernel.cu | 1 - paddle/phi/kernels/gpu/poisson_kernel.cu | 1 - paddle/phi/kernels/gpu/searchsorted_kernel.cu | 1 - paddle/phi/kernels/poisson_kernel.h | 1 - paddle/phi/kernels/searchsorted_kernel.h | 1 - .../fluid/tests/unittests/test_poisson_op.py | 40 +++--------- .../tests/unittests/test_searchsorted_op.py | 61 ++++++------------- 7 files changed, 26 insertions(+), 80 deletions(-) diff --git a/paddle/phi/kernels/gpu/poisson_grad_kernel.cu b/paddle/phi/kernels/gpu/poisson_grad_kernel.cu index 1990ba1fb7f48..be7d28a6630cc 100644 --- a/paddle/phi/kernels/gpu/poisson_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/poisson_grad_kernel.cu @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/poisson_grad_kernel_impl.h" diff --git a/paddle/phi/kernels/gpu/poisson_kernel.cu b/paddle/phi/kernels/gpu/poisson_kernel.cu index 3ba019b83a3f7..1d1968b30ae6e 100644 --- a/paddle/phi/kernels/gpu/poisson_kernel.cu +++ b/paddle/phi/kernels/gpu/poisson_kernel.cu @@ -21,7 +21,6 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" -#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/poisson_kernel.h" diff --git a/paddle/phi/kernels/gpu/searchsorted_kernel.cu b/paddle/phi/kernels/gpu/searchsorted_kernel.cu index 69b31803c7894..abfdcbd0e27ea 100644 --- a/paddle/phi/kernels/gpu/searchsorted_kernel.cu +++ b/paddle/phi/kernels/gpu/searchsorted_kernel.cu @@ -15,7 +15,6 @@ #include "paddle/phi/kernels/searchsorted_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/searchsorted_kernel_impl.h" diff --git a/paddle/phi/kernels/poisson_kernel.h b/paddle/phi/kernels/poisson_kernel.h index 012b50378ccfc..b2b2ea97f014e 100644 --- a/paddle/phi/kernels/poisson_kernel.h +++ b/paddle/phi/kernels/poisson_kernel.h @@ -15,7 +15,6 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/device_context.h" namespace phi { diff --git a/paddle/phi/kernels/searchsorted_kernel.h b/paddle/phi/kernels/searchsorted_kernel.h index 953caf9466d7e..e425c7fd79555 100644 --- a/paddle/phi/kernels/searchsorted_kernel.h +++ b/paddle/phi/kernels/searchsorted_kernel.h @@ -15,7 +15,6 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/device_context.h" namespace phi { diff --git a/python/paddle/fluid/tests/unittests/test_poisson_op.py b/python/paddle/fluid/tests/unittests/test_poisson_op.py index c440645a4a295..9b647cf6af385 100644 --- a/python/paddle/fluid/tests/unittests/test_poisson_op.py +++ b/python/paddle/fluid/tests/unittests/test_poisson_op.py @@ -43,17 +43,20 @@ class TestPoissonOp1(OpTest): def setUp(self): self.op_type = "poisson" self.python_api = paddle.tensor.poisson + self.init_dtype() self.config() self.attrs = {} self.inputs = {'X': np.full([2048, 1024], self.lam, dtype=self.dtype)} self.outputs = {'Out': np.ones([2048, 1024], dtype=self.dtype)} + def init_dtype(self): + self.dtype = "float64" + def config(self): self.lam = 10 self.a = 5 self.b = 15 - self.dtype = "float64" def verify_output(self, outs): hist, prob = output_hist(np.array(outs[0]), self.lam, self.a, self.b) @@ -369,37 +372,10 @@ def test_fixed_random_number(self): paddle.enable_static() -@unittest.skipIf( - not core.is_compiled_with_cuda() - or not core.is_float16_supported(core.CUDAPlace(0)), - "core is not complied with CUDA and not support the float16", -) -class TestPoissonFP16OP(OpTest): - def setUp(self): - self.op_type = "poisson" - self.python_api = paddle.tensor.poisson - self.config() - self.__class__.op_type = self.op_type - x = np.full([2048, 1024], self.lam, dtype=self.dtype) - out = np.ones([2048, 1024]) - self.attrs = {} - self.inputs = {'X': x.astype(self.dtype)} - self.outputs = {'Out': out} - - def config(self): - self.lam = 10 - self.a = 5 - self.b = 15 +class TestPoissonFP16OP(TestPoissonOp1): + def init_dtype(self): self.dtype = np.float16 - def test_check_output(self): - place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-3) - - def test_check_grad(self): - place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X'], 'Out', max_relative_error=1e-2) - @unittest.skipIf( not core.is_compiled_with_cuda() @@ -426,11 +402,11 @@ def config(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-3) + self.check_output_with_place(place) def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X'], 'Out', max_relative_error=1e-2) + self.check_grad_with_place(place, ['X'], 'Out') if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py index 79801d5576778..434bb6538511a 100644 --- a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py +++ b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py @@ -27,11 +27,12 @@ class TestSearchSorted(OpTest): def setUp(self): self.python_api = paddle.searchsorted self.op_type = "searchsorted" + self.init_dtype() self.init_test_case() self.inputs = { - 'SortedSequence': self.sorted_sequence, - 'Values': self.values, + 'SortedSequence': self.sorted_sequence.astype(self.dtype), + 'Values': self.values.astype(self.dtype), } self.attrs = {"out_int32": False, "right": False} self.attrs["right"] = True if self.side == 'right' else False @@ -41,6 +42,9 @@ def setUp(self): ) } + def init_dtype(self): + self.dtype = "float32" + def test_check_output(self): self.check_output(check_eager=True) @@ -215,39 +219,9 @@ def test_sortedsequence_values_type_error(): self.assertRaises(TypeError, test_sortedsequence_values_type_error) -@unittest.skipIf( - not core.is_compiled_with_cuda() - or not core.is_float16_supported(core.CUDAPlace(0)), - "core is not complied with CUDA and not support the float16", -) -class TestSearchSortedFP16OP(OpTest): - def setUp(self): - self.python_api = paddle.searchsorted - self.op_type = "searchsorted" - self.__class__.op_type = self.op_type +class TestSearchSortedFP16OP(TestSearchSorted): + def init_dtype(self): self.dtype = np.float16 - self.init_test_case() - - sorted_sequence = self.sorted_sequence - values = self.values - out = np.searchsorted(self.sorted_sequence, self.values, side=self.side) - - self.inputs = { - 'SortedSequence': sorted_sequence.astype(self.dtype), - 'Values': values.astype(self.dtype), - } - self.attrs = {"out_int32": False, "right": False} - self.attrs["right"] = True if self.side == 'right' else False - self.outputs = {'Out': out} - - def test_check_output(self): - place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-3, check_eager=True) - - def init_test_case(self): - self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype(np.float32) - self.values = np.array([[3, 6, 9], [3, 6, 9]]).astype(np.float32) - self.side = "left" @unittest.skipIf( @@ -259,25 +233,26 @@ class TestSearchSortedBF16(OpTest): def setUp(self): self.python_api = paddle.searchsorted self.op_type = "searchsorted" - self.__class__.op_type = self.op_type self.dtype = np.uint16 self.init_test_case() - sorted_sequence = self.sorted_sequence - values = self.values - out = np.searchsorted(self.sorted_sequence, self.values, side=self.side) - self.inputs = { - 'SortedSequence': convert_float_to_uint16(sorted_sequence), - 'Values': convert_float_to_uint16(values), + 'SortedSequence': convert_float_to_uint16(self.sorted_sequence), + 'Values': convert_float_to_uint16(self.values), } self.attrs = {"out_int32": False, "right": False} self.attrs["right"] = True if self.side == 'right' else False - self.outputs = {'Out': convert_float_to_uint16(out)} + self.outputs = { + 'Out': convert_float_to_uint16( + np.searchsorted( + self.sorted_sequence, self.values, side=self.side + ) + ) + } def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-3, check_eager=True) + self.check_output_with_place(place) def init_test_case(self): self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype(np.float32) From 9409af75dcd4928169426e4965ef2517666757f1 Mon Sep 17 00:00:00 2001 From: LoneRanger <836253168@qq.com> Date: Tue, 28 Mar 2023 14:08:09 +0800 Subject: [PATCH 04/14] Update test_searchsorted_op.py fix function name --- python/paddle/fluid/tests/unittests/test_searchsorted_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py index 52ba099f10f62..c2add40c0e73c 100644 --- a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py +++ b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py @@ -230,7 +230,7 @@ def init_dtype(self): or not core.is_bfloat16_supported(core.CUDAPlace(0)), "core is not complied with CUDA and not support the bfloat16", ) -class TestSearchSortedBF16(OpTest): +class TestSearchSortedBF16Op(OpTest): def setUp(self): self.python_api = paddle.searchsorted self.op_type = "searchsorted" From 6c7ac27ac97ff10eabab26f0e383fd63127b7575 Mon Sep 17 00:00:00 2001 From: LoneRanger <836253168@qq.com> Date: Tue, 28 Mar 2023 14:08:58 +0800 Subject: [PATCH 05/14] Update test_poisson_op.py fix function name --- python/paddle/fluid/tests/unittests/test_poisson_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_poisson_op.py b/python/paddle/fluid/tests/unittests/test_poisson_op.py index 9b647cf6af385..cb7a4aa518a30 100644 --- a/python/paddle/fluid/tests/unittests/test_poisson_op.py +++ b/python/paddle/fluid/tests/unittests/test_poisson_op.py @@ -382,7 +382,7 @@ def init_dtype(self): or not core.is_bfloat16_supported(core.CUDAPlace(0)), "core is not complied with CUDA and not support the bfloat16", ) -class TestPoissonBF16(OpTest): +class TestPoissonBF16Op(OpTest): def setUp(self): self.op_type = "poisson" self.python_api = paddle.tensor.poisson From 8f280ee6633522357130501599ab40c8884fb043 Mon Sep 17 00:00:00 2001 From: longranger2 <836253168@qq.com> Date: Mon, 3 Apr 2023 19:52:48 +0800 Subject: [PATCH 06/14] fix bug --- python/paddle/fluid/tests/unittests/test_poisson_op.py | 2 +- python/paddle/fluid/tests/unittests/test_searchsorted_op.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_poisson_op.py b/python/paddle/fluid/tests/unittests/test_poisson_op.py index cb7a4aa518a30..e9bc0f364c460 100644 --- a/python/paddle/fluid/tests/unittests/test_poisson_op.py +++ b/python/paddle/fluid/tests/unittests/test_poisson_op.py @@ -19,7 +19,7 @@ from eager_op_test import OpTest, convert_float_to_uint16 import paddle -import paddle.fluid.core as core +from paddle.fluid import core paddle.enable_static() paddle.seed(100) diff --git a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py index c2add40c0e73c..8c281c5dd68e8 100644 --- a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py +++ b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py @@ -17,7 +17,6 @@ import numpy as np from eager_op_test import OpTest, convert_float_to_uint16 - import paddle from paddle.fluid import core From 26e4f919c531b8cbab62f7d9ddad0b968e0dddbf Mon Sep 17 00:00:00 2001 From: longranger2 <836253168@qq.com> Date: Sat, 22 Apr 2023 20:26:23 +0800 Subject: [PATCH 07/14] remove the searchorted --- paddle/phi/kernels/gpu/searchsorted_kernel.cu | 4 +- .../tests/unittests/test_searchsorted_op.py | 51 ++----------------- 2 files changed, 4 insertions(+), 51 deletions(-) diff --git a/paddle/phi/kernels/gpu/searchsorted_kernel.cu b/paddle/phi/kernels/gpu/searchsorted_kernel.cu index abfdcbd0e27ea..b6d6a795a59e7 100644 --- a/paddle/phi/kernels/gpu/searchsorted_kernel.cu +++ b/paddle/phi/kernels/gpu/searchsorted_kernel.cu @@ -25,8 +25,6 @@ PD_REGISTER_KERNEL(searchsorted, float, double, int, - int64_t, - phi::dtype::float16, - phi::dtype::bfloat16) { + int64_t) { kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); } diff --git a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py index 8c281c5dd68e8..fa194cab5b1b5 100644 --- a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py +++ b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from eager_op_test import OpTest, convert_float_to_uint16 +from eager_op_test import OpTest import paddle from paddle.fluid import core @@ -27,12 +27,11 @@ class TestSearchSorted(OpTest): def setUp(self): self.python_api = paddle.searchsorted self.op_type = "searchsorted" - self.init_dtype() self.init_test_case() self.inputs = { - 'SortedSequence': self.sorted_sequence.astype(self.dtype), - 'Values': self.values.astype(self.dtype), + 'SortedSequence': self.sorted_sequence, + 'Values': self.values, } self.attrs = {"out_int32": False, "right": False} self.attrs["right"] = True if self.side == 'right' else False @@ -42,9 +41,6 @@ def setUp(self): ) } - def init_dtype(self): - self.dtype = "float32" - def test_check_output(self): self.check_output() @@ -219,46 +215,5 @@ def test_sortedsequence_values_type_error(): self.assertRaises(TypeError, test_sortedsequence_values_type_error) -class TestSearchSortedFP16OP(TestSearchSorted): - def init_dtype(self): - self.dtype = np.float16 - - -@unittest.skipIf( - not core.is_compiled_with_cuda() - or not core.is_bfloat16_supported(core.CUDAPlace(0)), - "core is not complied with CUDA and not support the bfloat16", -) -class TestSearchSortedBF16Op(OpTest): - def setUp(self): - self.python_api = paddle.searchsorted - self.op_type = "searchsorted" - self.dtype = np.uint16 - self.init_test_case() - - self.inputs = { - 'SortedSequence': convert_float_to_uint16(self.sorted_sequence), - 'Values': convert_float_to_uint16(self.values), - } - self.attrs = {"out_int32": False, "right": False} - self.attrs["right"] = True if self.side == 'right' else False - self.outputs = { - 'Out': convert_float_to_uint16( - np.searchsorted( - self.sorted_sequence, self.values, side=self.side - ) - ) - } - - def test_check_output(self): - place = core.CUDAPlace(0) - self.check_output_with_place(place) - - def init_test_case(self): - self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype(np.float32) - self.values = np.array([[3, 6, 9], [3, 6, 9]]).astype(np.float32) - self.side = "left" - - if __name__ == '__main__': unittest.main() From 1c2419b366bce8312de20c326d9b9c350fce0077 Mon Sep 17 00:00:00 2001 From: LoneRanger <836253168@qq.com> Date: Wed, 3 May 2023 00:01:22 +0800 Subject: [PATCH 08/14] Update test_poisson_op.py --- python/paddle/fluid/tests/unittests/test_poisson_op.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_poisson_op.py b/python/paddle/fluid/tests/unittests/test_poisson_op.py index e9bc0f364c460..a301deaf2f382 100644 --- a/python/paddle/fluid/tests/unittests/test_poisson_op.py +++ b/python/paddle/fluid/tests/unittests/test_poisson_op.py @@ -387,9 +387,8 @@ def setUp(self): self.op_type = "poisson" self.python_api = paddle.tensor.poisson self.config() - self.__class__.op_type = self.op_type - x = np.full([2048, 1024], self.lam, dtype=self.dtype) - out = np.ones([2048, 1024]) + x = np.full([248, 124], self.lam, dtype=self.dtype) + out = np.ones([248, 124]) self.attrs = {} self.inputs = {'X': convert_float_to_uint16(x)} self.outputs = {'Out': convert_float_to_uint16(out)} From 0525ca4054fd9b0068a6833634a28d43eff04f7a Mon Sep 17 00:00:00 2001 From: longranger2 <836253168@qq.com> Date: Wed, 10 May 2023 14:32:35 +0800 Subject: [PATCH 09/14] fix bug of TestPoissonBF16Op --- .../fluid/tests/unittests/test_poisson_op.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_poisson_op.py b/python/paddle/fluid/tests/unittests/test_poisson_op.py index a301deaf2f382..fb575f3737856 100644 --- a/python/paddle/fluid/tests/unittests/test_poisson_op.py +++ b/python/paddle/fluid/tests/unittests/test_poisson_op.py @@ -399,13 +399,25 @@ def config(self): self.b = 15 self.dtype = np.uint16 + def verify_output(self, outs): + hist, prob = output_hist(np.array(outs[0]), self.lam, self.a, self.b) + np.testing.assert_allclose(hist, prob, rtol=0.01) + def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place_customized(self.verify_output, place) def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X'], 'Out') + self.check_grad_with_place( + place, + ['X'], + 'Out', + user_defined_grads=[np.zeros([2048, 1024], dtype="float32")], + user_defined_grad_outputs=[ + np.random.rand(2048, 1024).astype("float32") + ], + ) if __name__ == "__main__": From 748612f7ed0077c3101c8d2a20278ebca1c391c5 Mon Sep 17 00:00:00 2001 From: LoneRanger <836253168@qq.com> Date: Thu, 11 May 2023 22:18:03 +0800 Subject: [PATCH 10/14] Update test_poisson_op.py --- python/paddle/fluid/tests/unittests/test_poisson_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_poisson_op.py b/python/paddle/fluid/tests/unittests/test_poisson_op.py index fb575f3737856..7b04296b40d71 100644 --- a/python/paddle/fluid/tests/unittests/test_poisson_op.py +++ b/python/paddle/fluid/tests/unittests/test_poisson_op.py @@ -387,8 +387,8 @@ def setUp(self): self.op_type = "poisson" self.python_api = paddle.tensor.poisson self.config() - x = np.full([248, 124], self.lam, dtype=self.dtype) - out = np.ones([248, 124]) + x = np.full([2048, 1024], self.lam, dtype=self.dtype) + out = np.ones([2048, 1024]) self.attrs = {} self.inputs = {'X': convert_float_to_uint16(x)} self.outputs = {'Out': convert_float_to_uint16(out)} From fded5548c89468c72b78ddec7294380bd3eb982f Mon Sep 17 00:00:00 2001 From: LoneRanger <836253168@qq.com> Date: Fri, 12 May 2023 16:32:01 +0800 Subject: [PATCH 11/14] Update test_poisson_op.py --- python/paddle/fluid/tests/unittests/test_poisson_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_poisson_op.py b/python/paddle/fluid/tests/unittests/test_poisson_op.py index 7b04296b40d71..d786af09ac35c 100644 --- a/python/paddle/fluid/tests/unittests/test_poisson_op.py +++ b/python/paddle/fluid/tests/unittests/test_poisson_op.py @@ -387,8 +387,8 @@ def setUp(self): self.op_type = "poisson" self.python_api = paddle.tensor.poisson self.config() - x = np.full([2048, 1024], self.lam, dtype=self.dtype) - out = np.ones([2048, 1024]) + x = np.full([2048, 1024], self.lam, dtype="float32") + out = np.ones([2048, 1024], dtype="float32") self.attrs = {} self.inputs = {'X': convert_float_to_uint16(x)} self.outputs = {'Out': convert_float_to_uint16(out)} From 229cc555814b43abce899eb4543399c8a73da105 Mon Sep 17 00:00:00 2001 From: LoneRanger <836253168@qq.com> Date: Mon, 15 May 2023 22:23:29 +0800 Subject: [PATCH 12/14] Update test_poisson_op.py --- python/paddle/fluid/tests/unittests/test_poisson_op.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_poisson_op.py b/python/paddle/fluid/tests/unittests/test_poisson_op.py index d786af09ac35c..00283ee25c8a4 100644 --- a/python/paddle/fluid/tests/unittests/test_poisson_op.py +++ b/python/paddle/fluid/tests/unittests/test_poisson_op.py @@ -389,6 +389,8 @@ def setUp(self): self.config() x = np.full([2048, 1024], self.lam, dtype="float32") out = np.ones([2048, 1024], dtype="float32") + x = convert_uint16_to_float(convert_float_to_uint16(x)) + out = convert_uint16_to_float(convert_float_to_uint16(out)) self.attrs = {} self.inputs = {'X': convert_float_to_uint16(x)} self.outputs = {'Out': convert_float_to_uint16(out)} From 37a69bdaafe5705dcfd133e8e8986598bb532110 Mon Sep 17 00:00:00 2001 From: longranger2 <836253168@qq.com> Date: Tue, 16 May 2023 09:58:11 +0800 Subject: [PATCH 13/14] fix bug of import --- python/paddle/fluid/tests/unittests/test_poisson_op.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_poisson_op.py b/python/paddle/fluid/tests/unittests/test_poisson_op.py index 00283ee25c8a4..06c4cf8590934 100644 --- a/python/paddle/fluid/tests/unittests/test_poisson_op.py +++ b/python/paddle/fluid/tests/unittests/test_poisson_op.py @@ -16,7 +16,11 @@ import unittest import numpy as np -from eager_op_test import OpTest, convert_float_to_uint16 +from eager_op_test import ( + OpTest, + convert_float_to_uint16, + convert_uint16_to_float, +) import paddle from paddle.fluid import core @@ -386,6 +390,7 @@ class TestPoissonBF16Op(OpTest): def setUp(self): self.op_type = "poisson" self.python_api = paddle.tensor.poisson + self.__class__.op_type = self.op_type self.config() x = np.full([2048, 1024], self.lam, dtype="float32") out = np.ones([2048, 1024], dtype="float32") From 5bda810183ef6b32df69ac41c64ce4321b0f2170 Mon Sep 17 00:00:00 2001 From: longranger2 <836253168@qq.com> Date: Fri, 28 Jul 2023 15:48:29 +0800 Subject: [PATCH 14/14] fix bug --- test/legacy_test/test_poisson_op.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/legacy_test/test_poisson_op.py b/test/legacy_test/test_poisson_op.py index 06c4cf8590934..84edf6a322189 100644 --- a/test/legacy_test/test_poisson_op.py +++ b/test/legacy_test/test_poisson_op.py @@ -394,8 +394,6 @@ def setUp(self): self.config() x = np.full([2048, 1024], self.lam, dtype="float32") out = np.ones([2048, 1024], dtype="float32") - x = convert_uint16_to_float(convert_float_to_uint16(x)) - out = convert_uint16_to_float(convert_float_to_uint16(out)) self.attrs = {} self.inputs = {'X': convert_float_to_uint16(x)} self.outputs = {'Out': convert_float_to_uint16(out)} @@ -407,7 +405,9 @@ def config(self): self.dtype = np.uint16 def verify_output(self, outs): - hist, prob = output_hist(np.array(outs[0]), self.lam, self.a, self.b) + hist, prob = output_hist( + convert_uint16_to_float(np.array(outs[0])), self.lam, self.a, self.b + ) np.testing.assert_allclose(hist, prob, rtol=0.01) def test_check_output(self):