From 3993fb4b0c19d9226a31b16bde0f986acb9db74d Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 2 Dec 2024 11:22:14 +0800 Subject: [PATCH] update --- .../paddle_scatter/composite/logsumexp.py | 2 +- .../paddle_scatter/composite/softmax.py | 2 +- .../paddle_scatter/composite/std.py | 2 +- .../paddle_scatter/csrc/atomics.cuh | 31 + .../paddle_scatter/csrc/index_info.cuh | 7 +- .../paddle_scatter/csrc/index_info.h | 7 +- .../paddle_scatter/csrc/scatter_min_max.cu | 2 +- .../paddle_scatter/csrc/segment_coo.cc | 500 ++++++++++++++ .../paddle_scatter/csrc/segment_coo.cu | 610 ++++++++++++++++++ .../paddle_scatter/csrc/segment_csr.cc | 409 ++++++++++++ .../paddle_scatter/csrc/segment_csr.cu | 382 +++++++++++ .../paddle_scatter/csrc/utils.cuh | 2 +- jointContribution/paddle_scatter/scatter.py | 73 ++- .../paddle_scatter/segment_coo.py | 195 +++--- .../paddle_scatter/segment_csr.py | 161 ++--- jointContribution/paddle_scatter/setup.py | 4 +- jointContribution/paddle_scatter/testing.py | 9 +- .../tests/composite/test_softmax.py | 2 +- .../paddle_scatter/tests/test_gather.py | 115 ++++ .../paddle_scatter/tests/test_multi_gpu.py | 4 +- .../paddle_scatter/tests/test_scatter.py | 122 +++- .../paddle_scatter/tests/test_segment.py | 147 +++++ jointContribution/paddle_scatter/utils.py | 37 -- 23 files changed, 2501 insertions(+), 324 deletions(-) create mode 100644 jointContribution/paddle_scatter/csrc/segment_coo.cc create mode 100644 jointContribution/paddle_scatter/csrc/segment_coo.cu create mode 100644 jointContribution/paddle_scatter/csrc/segment_csr.cc create mode 100644 jointContribution/paddle_scatter/csrc/segment_csr.cu diff --git a/jointContribution/paddle_scatter/composite/logsumexp.py b/jointContribution/paddle_scatter/composite/logsumexp.py index 3d66d040a..a3bed5c09 100644 --- a/jointContribution/paddle_scatter/composite/logsumexp.py +++ b/jointContribution/paddle_scatter/composite/logsumexp.py @@ -42,7 +42,7 @@ def scatter_logsumexp( ) index = broadcast(index, src, dim) - eps = paddle.to_tensor(eps, dtype=src.dtype) + eps = paddle.full([], eps, dtype=src.dtype) if out is not None: dim_size = out.shape[dim] diff --git a/jointContribution/paddle_scatter/composite/softmax.py b/jointContribution/paddle_scatter/composite/softmax.py index b31b54d44..dfa2eaa8c 100644 --- a/jointContribution/paddle_scatter/composite/softmax.py +++ b/jointContribution/paddle_scatter/composite/softmax.py @@ -84,7 +84,7 @@ def scatter_log_softmax( ) index = broadcast(index, src, dim) - eps = paddle.to_tensor(eps, dtype=src.dtype) + eps = paddle.full([], eps, dtype=src.dtype) max_value_per_index = scatter_max(src, index, dim=dim, dim_size=dim_size)[0] max_per_src_element = max_value_per_index.take_along_axis(indices=index, axis=dim) diff --git a/jointContribution/paddle_scatter/composite/std.py b/jointContribution/paddle_scatter/composite/std.py index 359e0b50d..6d27f4ad4 100644 --- a/jointContribution/paddle_scatter/composite/std.py +++ b/jointContribution/paddle_scatter/composite/std.py @@ -57,7 +57,7 @@ def scatter_std( res = scatter_sum(var, index, dim, out, dim_size) if unbiased: - count = count.subtract(paddle.to_tensor(1, dtype=src.dtype)).clip(1) + count = count.subtract(paddle.full([], 1, dtype=src.dtype)).clip(1) res = res.divide(count + 1e-6).sqrt() if out is not None: diff --git a/jointContribution/paddle_scatter/csrc/atomics.cuh b/jointContribution/paddle_scatter/csrc/atomics.cuh index 224d29889..f5b5de08d 100644 --- a/jointContribution/paddle_scatter/csrc/atomics.cuh +++ 
b/jointContribution/paddle_scatter/csrc/atomics.cuh @@ -150,3 +150,34 @@ static inline __device__ void atomMin(float *address, float val) { static inline __device__ void atomMin(double *address, double val) { AtomicMinDecimalImpl()(address, val); } + +#define OP(X, Y) Y + X +ATOMIC(Add) +#undef OP +static inline __device__ void atomAdd(uint8_t *address, uint8_t val) { + AtomicAddIntegerImpl()(address, val); +} +static inline __device__ void atomAdd(int8_t *address, int8_t val) { + AtomicAddIntegerImpl()(address, val); +} +static inline __device__ void atomAdd(int16_t *address, int16_t val) { + AtomicAddIntegerImpl()(address, val); +} +static inline __device__ void atomAdd(int32_t *address, int32_t val) { + atomicAdd(address, val); +} +static inline __device__ void atomAdd(int64_t *address, int64_t val) { + AtomicAddIntegerImpl()(address, val); +} +static inline __device__ void atomAdd(float *address, float val) { + atomicAdd(address, val); +} +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000) +static inline __device__ void atomAdd(double *address, double val) { + AtomicAddDecimalImpl()(address, val); +} +#else +static inline __device__ void atomAdd(double *address, double val) { + atomicAdd(address, val); +} +#endif diff --git a/jointContribution/paddle_scatter/csrc/index_info.cuh b/jointContribution/paddle_scatter/csrc/index_info.cuh index 64ecd28de..f7bb665af 100644 --- a/jointContribution/paddle_scatter/csrc/index_info.cuh +++ b/jointContribution/paddle_scatter/csrc/index_info.cuh @@ -2,14 +2,14 @@ #include "paddle/extension.h" -#define MAX_TENSORINFO_DIMS 25 +#define MAX_TENSORINFO_DIMS 7 template struct TensorInfo { TensorInfo(const T *p, int dim, int sz[MAX_TENSORINFO_DIMS], int st[MAX_TENSORINFO_DIMS]) { data = p; dims = dim; - PD_CHECK(dims < MAX_TENSORINFO_DIMS, "Input dims should be smaller than 25."); + PD_CHECK(dims < MAX_TENSORINFO_DIMS, "Input dims should be smaller than 7."); for (int i = 0; i < dim; ++i) { sizes[i] = sz[i]; @@ -30,11 +30,8 @@ TensorInfo getTensorInfo(const paddle::Tensor &tensor) { int strides[MAX_TENSORINFO_DIMS]; int dims = tensor.shape().size(); - // int stride_i = 1; for (int i = dims - 1; i >= 0; --i) { sizes[i] = tensor.shape()[i]; - // strides[i] = stride_i; - // stride_i *= sizes[i]; sizes[i] = tensor.strides()[i]; } diff --git a/jointContribution/paddle_scatter/csrc/index_info.h b/jointContribution/paddle_scatter/csrc/index_info.h index c2751898b..3d27cacc7 100644 --- a/jointContribution/paddle_scatter/csrc/index_info.h +++ b/jointContribution/paddle_scatter/csrc/index_info.h @@ -2,14 +2,14 @@ #include "paddle/extension.h" -#define MAX_TENSORINFO_DIMS 25 +#define MAX_TENSORINFO_DIMS 7 template struct TensorInfo { TensorInfo(const T *p, int dim, int sz[MAX_TENSORINFO_DIMS], int st[MAX_TENSORINFO_DIMS]) { data = p; dims = dim; - PD_CHECK(dims < MAX_TENSORINFO_DIMS, "Input dims should be smaller than 25."); + PD_CHECK(dims < MAX_TENSORINFO_DIMS, "Input dims should be smaller than 7."); for (int i = 0; i < dim; ++i) { sizes[i] = sz[i]; @@ -29,11 +29,8 @@ TensorInfo getTensorInfo(const paddle::Tensor &tensor) { int strides[MAX_TENSORINFO_DIMS]; int dims = tensor.shape().size(); - // int stride_i = 1; for (int i = dims - 1; i >= 0; --i) { sizes[i] = tensor.shape()[i]; - // strides[i] = stride_i; - // stride_i *= sizes[i]; strides[i] = tensor.strides()[i]; } diff --git a/jointContribution/paddle_scatter/csrc/scatter_min_max.cu b/jointContribution/paddle_scatter/csrc/scatter_min_max.cu index 5fccb6f6a..089def620 100644 --- 
a/jointContribution/paddle_scatter/csrc/scatter_min_max.cu +++ b/jointContribution/paddle_scatter/csrc/scatter_min_max.cu @@ -135,7 +135,7 @@ std::vector scatter_min_max_cuda_forward(const paddle::Tensor& x using MPType = typename MPTypeTrait::Type; paddle::Tensor out_mp; if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) { - out_mp = paddle::empty(return_shape, paddle::DataType::FLOAT32, x.place()); + out_mp = paddle::experimental::cast(out, paddle::DataType::FLOAT32); } else { out_mp = out; } diff --git a/jointContribution/paddle_scatter/csrc/segment_coo.cc b/jointContribution/paddle_scatter/csrc/segment_coo.cc new file mode 100644 index 000000000..f763d50e9 --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/segment_coo.cc @@ -0,0 +1,500 @@ +#include "paddle/extension.h" + +#include +#include +#include + +#include "index_info.h" +#include "utils.h" + + +template +void segment_coo_cpu_forward_kernel(const data_t* x_data, + const index_t* index_data, + const std::vector& return_shape, + const std::vector& x_dims, + const std::string& reduce, + const TensorInfo& index_info, + int64_t x_numel, + int64_t dim, + bool post_process, + data_t* out_data, + index_t* arg_out_data, + data_t* count_data) { + using MPType = typename MPTypeTrait::Type; + auto B = 1; + for (auto i = 0; i < dim; ++i) + B *= x_dims[i]; + auto E = x_dims[dim]; + auto K = x_numel / (B * E); + auto N = return_shape[dim]; + + auto stride = index_info.strides[index_info.dims - 1]; + std::vector args(K); + std::vector vals(K); + + int64_t idx, next_idx, row_start; + for (auto b = 0; b < B; b++) { + auto offset = IndexToOffset::get(b * E, index_info); + idx = index_info.data[offset]; + + for (auto k = 0; k < K; k++) + vals[k] = static_cast(out_data[b * N * K + k]); + + row_start = 0; + for (auto e = 0; e < E; e++) { + // update + for (auto k = 0; k < K; k++) { + auto cmp = static_cast(x_data[b * E * K + e * K + k]); + if ((reduce == "min" && cmp < vals[k]) || + (reduce == "max" && cmp > vals[k])) { + vals[k] = cmp; + args[k] = e; + } else if (reduce == "sum" || reduce == "mean") { + vals[k] += cmp; + } + } + //write + if (e == E - 1) { + for (auto k = 0; k < K; k++) { + auto idx_k = b * N * K + idx * K + k; + auto count = E - row_start; + if (reduce == "sum") { + out_data[idx_k] = static_cast(vals[k]); + } else if (reduce == "mean") { + out_data[idx_k] = static_cast(vals[k] / static_cast(count > 0 ? count : 1)); + } else if (reduce == "min" || reduce == "max") { + if (count > 0) { + out_data[idx_k] = static_cast(vals[k]); + arg_out_data[idx_k] = args[k]; + } else { + out_data[idx_k] = static_cast(0); + } + } + } + + if (reduce == "mean") + count_data[b * N + idx] = static_cast(e + 1 - row_start); + } else { + next_idx = index_info.data[offset + (e + 1) * stride]; + assert(idx <= next_idx); + + if (idx != next_idx) { + //write + for (auto k = 0; k < K; k++) { + auto idx_k = b * N * K + idx * K + k; + auto count = e + 1 - row_start; + if (reduce == "sum") { + out_data[idx_k] = static_cast(vals[k]); + } else if (reduce == "mean") { + out_data[idx_k] = static_cast(vals[k] / static_cast(count > 0 ? 
count : 1)); + } else if (reduce == "min" || reduce == "max") { + if (count > 0) { + out_data[idx_k] = static_cast(vals[k]); + arg_out_data[idx_k] = args[k]; + } else { + out_data[idx_k] = static_cast(0); + } + } + + vals[k] = static_cast(out_data[b * N * K + next_idx * K + k]); + } + if (reduce == "mean") + count_data[b * N + idx] = static_cast(e + 1 - row_start); + row_start = e + 1; + } + + idx = next_idx; + } + } + } + + if (post_process) { + if (reduce == "min" || reduce == "max") { + auto out_numel = std::accumulate(return_shape.begin(), return_shape.end(), 1.0, std::multiplies()); + data_t init_val = static_cast((reduce == "min") ? std::numeric_limits::max() : std::numeric_limits::lowest()); + for (auto i = 0; i < out_numel; ++i) { + if (out_data[i] == init_val) + out_data[i] = static_cast(0.0); + } + } + if (reduce == "mean") { + auto count_data_numel = sizeof(count_data) / sizeof(data_t); + for (auto i = 0; i < count_data_numel; ++i) { + if (count_data[i] < static_cast(1.0)) + count_data[i] = static_cast(1.0); + } + } + } +} + +std::vector segment_coo_cpu_forward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + const std::vector& return_shape, + const std::string& reduce) { + CHECK_CPU(index); + if (init) + CHECK_CPU(init.get()); + + auto x_dims = x.shape(); + auto index_dims = index.shape(); + CHECK_INPUT(x_dims.size() >= index_dims.size()); + for (auto i = 0; i < index_dims.size() - 1; ++i) + CHECK_INPUT(x_dims[i] >= index_dims[i]); + + auto dim = index_dims.size() - 1; + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + paddle::Tensor arg_out; + if (reduce == "min" || reduce == "max") { + arg_out = paddle::experimental::full_like(out, x_dims[dim], index.dtype(), index.place()); + } else if (reduce == "mean") { + auto sizes = index.shape(); + sizes[dim] = return_shape[dim]; + arg_out = paddle::zeros(sizes, out.dtype(), index.place()); + } + + PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES( + x.dtype(), "segment_coo_cpu_forward_kernel", ([&] { + + using MPType = typename MPTypeTrait::Type; + if (!init) { + if (reduce == "min") + paddle::experimental::fill_(out, static_cast(std::numeric_limits::max())); + else if (reduce == "max") + paddle::experimental::fill_(out, static_cast(std::numeric_limits::lowest())); + else if (reduce == "sum" || reduce == "mean") + paddle::experimental::fill_(out, static_cast(0)); + } + + bool post_process = (!init) ? true : false; + switch(index.dtype()) { + case paddle::DataType::INT32: + { + auto index_info = getTensorInfo(index); + segment_coo_cpu_forward_kernel( + x.data(), + index.data(), + return_shape, + x_dims, + reduce, + index_info, + x.numel(), + dim, + post_process, + out.data(), + (reduce == "min" || reduce == "max") ? arg_out.data() : nullptr, + (reduce == "mean") ? arg_out.data() : nullptr); + break; + } + case paddle::DataType::INT64: + { + auto index_info = getTensorInfo(index); + segment_coo_cpu_forward_kernel( + x.data(), + index.data(), + return_shape, + x_dims, + reduce, + index_info, + x.numel(), + dim, + post_process, + out.data(), + (reduce == "min" || reduce == "max") ? arg_out.data() : nullptr, + (reduce == "mean") ? 
arg_out.data() : nullptr); + break; + } + default: + PD_THROW( + "function segment_coo_cpu_forward_kernel is not implemented for the index data type `", + phi::DataTypeToString(index.dtype()), "`"); + } + })); + + return {out, arg_out}; +} + +template +void gather_coo_cpu_forward_kernel(const data_t* x_data, + const TensorInfo& index_info, + int stride, + int B, + int E, + int K, + int N, + data_t* out_data) { + std::vector vals(K); + int64_t idx, next_idx; + for (auto b = 0; b < B; b++) { + auto offset = IndexToOffset::get(b * E, index_info); + idx = index_info.data[offset]; + + for (auto k = 0; k < K; k++) + vals[k] = x_data[b * N * K + idx * K + k]; + + for (auto e = 0; e < E; e++) { + for (auto k = 0; k < K; k++) + out_data[b * E * K + e * K + k] = vals[k]; + + if (e < E - 1) { + next_idx = index_info.data[offset + (e + 1) * stride]; + CHECK_INPUT(idx <= next_idx); + + if (idx != next_idx) { + idx = next_idx; + for (auto k = 0; k < K; k++) + vals[k] = x_data[b * N * K + idx * K + k]; + } + } + } + } +} + +std::vector gather_coo_cpu_forward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + std::vector return_shape) { + CHECK_CPU(index); + if (init) + CHECK_CPU(init.get()); + + auto x_dims = x.shape(); + auto index_dims = index.shape(); + CHECK_INPUT(x_dims.size() >= index_dims.size()); + for (auto i = 0; i < index_dims.size() - 1; ++i) + CHECK_INPUT(x_dims[i] == index_dims[i]); + + auto dim = index_dims.size() - 1; + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + auto B = index.numel() / return_shape[dim]; + auto E = index_dims[dim]; + auto K = out.numel() / index.numel(); + auto N = x_dims[dim]; + + PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES( + x.dtype(), "gather_coo_cpu_forward_kernel", ([&] { + switch(index.dtype()) { + case paddle::DataType::INT32: + { + auto index_info = getTensorInfo(index); + auto stride = index_info.strides[index_info.dims - 1]; + gather_coo_cpu_forward_kernel( + x.data(), index_info, stride, + B, E, K, N, out.data()); + break; + } + case paddle::DataType::INT64: + { + auto index_info = getTensorInfo(index); + auto stride = index_info.strides[index_info.dims - 1]; + gather_coo_cpu_forward_kernel( + x.data(), index_info, stride, + B, E, K, N, out.data()); + break; + } + default: + PD_THROW( + "function gather_coo_cpu_forward_kernel is not implemented for the index data type `", + phi::DataTypeToString(index.dtype()), "`"); + } + })); + + return {out}; +} + +std::vector segment_coo_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + std::vector return_shape, + std::string reduce); + +std::vector SegmentCooForward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + std::vector return_shape, + std::string reduce) { + if (x.is_cpu()) { + return segment_coo_cpu_forward(x, index, init, return_shape, reduce); + } else if (x.is_gpu()) { + return segment_coo_cuda_forward(x, index, init, return_shape, reduce); + } else { + PD_THROW("Unsupported device type for forward function of custom segment_coo operator."); + } +} + + +std::vector gather_coo_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + std::vector 
return_shape); + +std::vector GatherCooForward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + std::vector return_shape) { + if (x.is_cpu()) { + return gather_coo_cpu_forward(x, index, init, return_shape); + } else if (x.is_gpu()) { + return gather_coo_cuda_forward(x, index, init, return_shape); + } else { + PD_THROW("Unsupported device type for forward function of custom gather_coo operator."); + } +} + +std::vector SegmentCooBackward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& arg_out, + const paddle::Tensor& grad_out, + std::string reduce) { + if (!x.is_cpu() && !x.is_gpu() ) { + PD_THROW("Unsupported device type for backward function of custom segment_coo operator."); + } + if (reduce == "min" || reduce == "max") { + int64_t dim = index.shape().size() - 1; + auto x_shape = x.shape(); + x_shape[dim] += 1; + auto grad_x = paddle::zeros(x_shape, x.dtype(), x.place()); + paddle::experimental::put_along_axis_(grad_x, arg_out.get(), grad_out, dim); + grad_x = paddle::experimental::slice(grad_x, {dim}, {0}, {x_shape[dim] - 1}, {1}, {}); + return {grad_x}; + } else if (reduce == "mean") { + auto grad_x = paddle::empty(x.shape(), x.dtype(), x.place()); + paddle::Tensor count = arg_out.get(); + paddle::Tensor grad_in = GatherCooForward(grad_out, index, paddle::optional(grad_x), grad_x.shape())[0]; + auto sizes = arg_out.get().shape(); + int64_t dim = index.shape().size() - 1; + sizes[dim] = index.shape()[dim]; + count = GatherCooForward(count, index, paddle::optional(paddle::none), sizes)[0]; + for (auto i = 0; i < grad_out.shape().size() - index.shape().size(); i++) + count = paddle::experimental::unsqueeze(count, {-1}); + paddle::experimental::divide_(grad_in, count); + return {grad_in}; + } else if (reduce == "sum") { + auto grad_x = paddle::empty(x.shape(), x.dtype(), x.place()); + paddle::Tensor grad_in = GatherCooForward(grad_out, index, paddle::optional(grad_x), grad_x.shape())[0]; + return {grad_in}; + } +} + +std::vector GatherCooBackward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::Tensor& grad_out) { + if (!x.is_cpu() && !x.is_gpu() ) { + PD_THROW("Unsupported device type for backward function of custom gather_coo operator."); + } + auto x_shape = x.shape(); + auto grad_x = paddle::zeros(x_shape, x.dtype(), x.place()); + paddle::Tensor grad_in = SegmentCooForward(grad_out, index, paddle::optional(grad_x), grad_x.shape(), "sum")[0]; + return {grad_in}; +} + +std::vector> SegmentCooFWInferShape(const std::vector& x_shape, + const std::vector& index_shape, + const paddle::optional>& init_shape, + std::vector return_shape, + std::string reduce) { + return {return_shape, return_shape}; +} + +std::vector SegmentCooFWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& index_dtype, + const paddle::optional& init_dtype) { + return {x_dtype, index_dtype}; +} + +std::vector> SegmentCooBWInferShape(const std::vector& x_shape, + const std::vector& index_shape, + const paddle::optional>& arg_out_shape, + const std::vector& grad_out_shape) { + return {x_shape}; +} + +std::vector SegmentCooBWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& index_dtype, + const paddle::optional& arg_out_dtype, + const paddle::DataType& grad_out_dtype) { + return {x_dtype}; +} + +std::vector> GatherCooFWInferShape(const std::vector& x_shape, + const std::vector& index_shape, + const paddle::optional>& init_shape, + std::vector return_shape) { + return {return_shape}; +} + 
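+// Note on the COO ops registered at the end of this file: `Index` must be
+// sorted (non-decreasing) along its last dimension, and Out[..., i, ...]
+// reduces the slice of X whose index equals i with the chosen `reduce`
+// ("sum", "mean", "min" or "max"). For "min"/"max", ArgOut records the
+// position within the indexed dimension (empty segments yield 0, with the
+// arg left at X's size along that dimension); for "mean", ArgOut carries the
+// per-segment counts used for normalization. gather_coo is the inverse
+// mapping: it copies X[..., index[e], ...] to Out[..., e, ...].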
+std::vector GatherCooFWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& index_dtype, + const paddle::optional& init_dtype) { + return {x_dtype}; +} + +std::vector> GatherCooBWInferShape(const std::vector& x_shape, + const std::vector& index_shape, + const std::vector& grad_out_shape) { + return {x_shape}; +} + +std::vector GatherCooBWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& index_dtype, + const paddle::DataType& grad_out_dtype) { + return {x_dtype}; +} + + +PD_BUILD_OP(custom_segment_coo) + .Inputs({"X", "Index", paddle::Optional("Init")}) + .Outputs({"Out", paddle::Optional("ArgOut")}) + .Attrs({"return_shape: std::vector", + "reduce: std::string"}) + .SetKernelFn(PD_KERNEL(SegmentCooForward)) + .SetInferShapeFn(PD_INFER_SHAPE(SegmentCooFWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SegmentCooFWInferDtype)); + +PD_BUILD_GRAD_OP(custom_segment_coo) + .Inputs({"X", "Index", paddle::Optional("ArgOut"), paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .Attrs({"reduce: std::string"}) + .SetKernelFn(PD_KERNEL(SegmentCooBackward)) + .SetInferShapeFn(PD_INFER_SHAPE(SegmentCooBWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SegmentCooBWInferDtype)); + +PD_BUILD_OP(custom_gather_coo) + .Inputs({"X", "Index", paddle::Optional("Init")}) + .Outputs({"Out"}) + .Attrs({"return_shape: std::vector"}) + .SetKernelFn(PD_KERNEL(GatherCooForward)) + .SetInferShapeFn(PD_INFER_SHAPE(GatherCooFWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(GatherCooFWInferDtype)); + +PD_BUILD_GRAD_OP(custom_gather_coo) + .Inputs({"X", "Index", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(GatherCooBackward)) + .SetInferShapeFn(PD_INFER_SHAPE(GatherCooBWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(GatherCooBWInferDtype)); diff --git a/jointContribution/paddle_scatter/csrc/segment_coo.cu b/jointContribution/paddle_scatter/csrc/segment_coo.cu new file mode 100644 index 000000000..959331128 --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/segment_coo.cu @@ -0,0 +1,610 @@ +#include "paddle/extension.h" + +#include +#include +#include + +#include "atomics.cuh" +#include "index_info.cuh" +#include "utils.cuh" + +#define THREADS 256 +#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS +#define FULL_MASK 0xffffffff + + +enum ReductionType { MIN, MAX, SUM, MEAN }; + +const std::map reduce2REDUCE = { + {"min", MIN}, {"max", MAX}, {"sum", SUM}, {"mean", MEAN} +}; + +bool is_floating_point(const phi::DataType& dtype) { + return dtype == phi::DataType::BFLOAT16 || dtype == phi::DataType::FLOAT16 || + dtype == phi::DataType::FLOAT32 || dtype == phi::DataType::FLOAT64; +} + +template +__global__ void +segment_coo_cuda_forward_kernel(const data_t* x_data, + const TensorInfo index_info, + ReductionType reduce_type, + mp_t* out_data, + size_t E, + size_t N) { + + // Each thread processes exactly one entry. Within a warp, we perform a + // parallel reduction across equal indices, and write the intermediate + // result via atomics. + + int row_idx = blockIdx.x * blockDim.x + threadIdx.x; + int lane_idx = row_idx & (32 - 1); + int D = index_info.sizes[index_info.dims - 1]; + + if (row_idx < E) { + int offset = IndexToOffset::get(row_idx, index_info); + int64_t idx = index_info.data[offset], next_idx; + int out_idx = (row_idx / D) * N + idx; + + mp_t val = HAS_VAL ? static_cast(x_data[row_idx]) : (mp_t)1, tmp; + +#pragma unroll + for (int i = 1; i < 32; i *= 2) { + // Parallel reduction inside a single warp. 
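+      // Each lane pulls the running value and segment index from the lane
+      // `i` steps below it. When that lane belongs to the same batch row and
+      // carries the same segment index, its partial result is folded in, so
+      // after log2(32) steps every lane holds the reduction of the run of
+      // equal indices ending at it. Only the last lane of each run (detected
+      // below via a shuffle-down of `idx`) performs the atomic write.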
+ tmp = SHFL_UP_SYNC(FULL_MASK, val, i); + next_idx = SHFL_UP_SYNC(FULL_MASK, idx, i); + if (lane_idx >= i && row_idx / D == (row_idx - i) / D) { + assert(idx >= next_idx); + if (idx == next_idx) { + // update + if ((reduce_type == MIN && tmp < val) || + (reduce_type == MAX && tmp > val)) + val = tmp; + else if (reduce_type == SUM || reduce_type == MEAN) + val += tmp; + } + } + } + + next_idx = SHFL_DOWN_SYNC(FULL_MASK, idx, 1); + if (lane_idx == 32 - 1 || row_idx / D != (row_idx + 1) / D || + idx != next_idx) { + // atomic_write + switch(reduce_type) { + case MIN: + atomMin(out_data + out_idx, val); + break; + case MAX: + atomMax(out_data + out_idx, val); + break; + case SUM: + atomAdd(out_data + out_idx, val); + break; + case MEAN: + atomAdd(out_data + out_idx, val); + break; + } + } + } +} + +template +__global__ void segment_coo_broadcast_cuda_forward_kernel(const data_t *x_data, + const TensorInfo index_info, + ReductionType reduce_type, + mp_t *out_data, + size_t E, + size_t K, + size_t N) { + + // Each thread processes a single column and `TB` index entries. Coalesced + // read and write is performed in column-major order. The intermediate + // results are written via atomics. + + int D = index_info.sizes[index_info.dims - 1]; + int E_1 = E / D; + int E_2 = (D - 1) + TB - ((D - 1) % TB); + + int row_idx = blockIdx.x * blockDim.y + threadIdx.y; + int col_idx = blockIdx.y * blockDim.x + threadIdx.x; + + int dim_start = (row_idx * TB) / E_2; + int row_start = (row_idx * TB) % E_2; + + if (dim_start < E_1 && col_idx < K) { + + int offset = IndexToOffset::get(dim_start * D + row_start, index_info); + int idx1 = __ldg(index_info.data + offset), idx2; + + mp_t val = static_cast(x_data[K * (dim_start * D + row_start) + col_idx]); + +#pragma unroll + for (int i = 1; i < TB; i++) { + if (row_start + i >= D) + break; + + idx2 = __ldg(index_info.data + offset + + i * index_info.strides[index_info.dims - 1]); + assert(idx1 <= idx2); + if (idx1 == idx2) { + mp_t tmp = static_cast(x_data[K * (dim_start * D + row_start + i) + col_idx]); + // update + if ((reduce_type == MIN && tmp < val) || + (reduce_type == MAX && tmp > val)) + val = tmp; + else if (reduce_type == SUM || reduce_type == MEAN) + val += tmp; + } else { + // atomic_write + switch(reduce_type) { + case MIN: + atomMin(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + case MAX: + atomMax(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + case SUM: + atomAdd(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + case MEAN: + atomAdd(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + } + val = x_data[K * (dim_start * D + row_start + i) + col_idx]; + } + + idx1 = idx2; + } + // atomic_write + switch(reduce_type) { + case MIN: + atomMin(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + case MAX: + atomMax(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + case SUM: + atomAdd(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + case MEAN: + atomAdd(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + } + } +} + + +template +__global__ void segment_coo_arg_kernel(const data_t *x_data, + const TensorInfo index_info, + mp_t *out_data, + index_t *arg_out_data, + size_t E, + size_t N) { + + int row_idx = blockIdx.x * blockDim.x + threadIdx.x; + int D = index_info.sizes[index_info.dims - 1]; + + if (row_idx < E) { + int offset = IndexToOffset::get(row_idx, index_info); + index_t idx = index_info.data[offset]; + int out_idx = (row_idx / 
D) * N + idx; + + mp_t val = __ldg(out_data + out_idx); + if (static_cast(x_data[row_idx]) == val) + arg_out_data[out_idx] = row_idx % D; + } +} + +template +__global__ void segment_coo_arg_broadcast_kernel(const data_t *x_data, + const TensorInfo index_info, + mp_t *out_data, + index_t *arg_out_data, + size_t E, + size_t K, + size_t N) { + + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / K; + int col_idx = thread_idx % K; + int D = index_info.sizes[index_info.dims - 1]; + + if (row_idx < E && col_idx < K) { + int offset = IndexToOffset::get(row_idx, index_info); + int idx = __ldg(index_info.data + offset); + int out_idx = ((row_idx / D) * N + idx) * K + col_idx; + + mp_t val = __ldg(out_data + out_idx); + if (static_cast(x_data[thread_idx]) == val) + arg_out_data[out_idx] = row_idx % D; + } +} + +template +__global__ void post_process_kernel(data_t init_val, + int numel, + data_t* out_data) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (tid < numel) { + if (out_data[tid] == init_val) + out_data[tid] = static_cast(0.0); + } +} + +template +__global__ void post_process_mean_kernel(mp_t* count_data, + int numel) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (tid < numel) { + if (count_data[tid] < static_cast(1.0)) + count_data[tid] = static_cast(1.0); + } +} + +std::vector segment_coo_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + std::vector return_shape, + std::string reduce) { + CHECK_CUDA(index); + if (init) + CHECK_CUDA(init.get()); + + auto x_dims = x.shape(); + auto index_dims = index.shape(); + CHECK_INPUT(x_dims.size() >= index_dims.size()); + for (auto i = 0; i < index_dims.size() - 1; ++i) + CHECK_INPUT(x_dims[i] >= index_dims[i]); + + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + auto dim = index_dims.size() - 1; + paddle::Tensor arg_out; + int count_numel; + if (reduce == "min" || reduce == "max") { + arg_out = paddle::experimental::full_like(out, x_dims[dim], index.dtype(), index.place()); + } else if (reduce == "mean") { + auto sizes = index.shape(); + sizes[dim] = return_shape[dim]; + arg_out = paddle::zeros(sizes, out.dtype(), index.place()); + count_numel = std::accumulate(sizes.begin(), sizes.end(), 1.0, std::multiplies()); + } + + auto E = index.numel(); + auto E_2 = index.shape()[dim]; + auto E_1 = index.numel() / E_2; + auto K = x.numel() / E; + auto N = out.shape()[dim]; + auto avg_len = (float)E_2 / (float)N; + + PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_2_TYPES( + paddle::DataType::FLOAT16, paddle::DataType::BFLOAT16, + x.dtype(), "segment_coo_cuda_forward_kernel", ([&] { + + using MPType = typename MPTypeTrait::Type; + paddle::Tensor out_mp; + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) { + out_mp = paddle::experimental::cast(out, paddle::DataType::FLOAT32); + } else { + out_mp = out; + } + + if (!init) { + if (reduce == "min") + paddle::experimental::fill_(out_mp, std::numeric_limits::max()); + else if (reduce == "max") + paddle::experimental::fill_(out_mp, std::numeric_limits::lowest()); + else if (reduce == "sum" || reduce == "mean") + paddle::experimental::fill_(out_mp, static_cast(0)); + } + + const data_t* x_data = 
x.data(); + MPType* out_data = out_mp.data(); + auto out_numel = std::accumulate(return_shape.begin(), return_shape.end(), 1.0, std::multiplies()); + + switch(index.dtype()) { + case paddle::DataType::INT32: + { + auto index_info = getTensorInfo(index); + if (K == 1) + segment_coo_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, N); + + else if (avg_len <= 8) + segment_coo_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + else if (avg_len <= 16) + segment_coo_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + else if (avg_len <= 32) + segment_coo_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + else + segment_coo_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + + if (reduce == "min" || reduce == "max") { + int* arg_out_data = arg_out.data(); + if (K == 1) + segment_coo_arg_kernel + <<>>( + x_data, index_info, out_data, arg_out_data, E, N); + else + segment_coo_arg_broadcast_kernel + <<>>( + x_data, index_info, out_data, arg_out_data, E, K, N); + } + + if (reduce == "mean") { + paddle::Tensor arg_out_mp; + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) { + arg_out_mp = paddle::experimental::cast(arg_out, paddle::DataType::FLOAT32); + } else { + arg_out_mp = arg_out; + } + auto count_data = arg_out_mp.data(); + segment_coo_cuda_forward_kernel + <<>>( + nullptr, index_info, reduce2REDUCE.at(reduce), + count_data, E, N); + post_process_mean_kernel + <<>>( + count_data, count_numel); + paddle::Tensor count = arg_out_mp; + for (int i = dim + 1; i < return_shape.size(); i++) { + count = paddle::experimental::unsqueeze(arg_out_mp, {-1}); + } + if (is_floating_point(out.dtype())) + paddle::experimental::divide_(out_mp, count); + else + paddle::experimental::floor_divide_(out_mp, count); + + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) { + arg_out = paddle::experimental::cast(arg_out_mp, x.dtype()); + } + } + + break; + } + case paddle::DataType::INT64: + { + auto index_info = getTensorInfo(index); + if (K == 1) + segment_coo_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, N); + + else if (avg_len <= 8) + segment_coo_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + else if (avg_len <= 16) + segment_coo_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + else if (avg_len <= 32) + segment_coo_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + else + segment_coo_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + + if (reduce == "min" || reduce == "max") { + int64_t* arg_out_data = arg_out.data(); + if (K == 1) + segment_coo_arg_kernel + <<>>( + x_data, index_info, out_data, arg_out_data, E, N); + else + segment_coo_arg_broadcast_kernel + <<>>( + x_data, index_info, out_data, arg_out_data, E, K, N); + } + + if (reduce == "mean") { + paddle::Tensor arg_out_mp; + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) { + arg_out_mp = paddle::experimental::cast(arg_out, paddle::DataType::FLOAT32); + } else { + arg_out_mp = arg_out; + } + auto count_data = 
arg_out_mp.data(); + segment_coo_cuda_forward_kernel + <<>>( + nullptr, index_info, reduce2REDUCE.at(reduce), + count_data, E, N); + post_process_mean_kernel + <<>>( + count_data, count_numel); + paddle::Tensor count = arg_out_mp; + for (int i = dim + 1; i < return_shape.size(); i++) { + count = paddle::experimental::unsqueeze(arg_out_mp, {-1}); + } + if (is_floating_point(out.dtype())) + paddle::experimental::divide_(out_mp, count); + else + paddle::experimental::floor_divide_(out_mp, count); + + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) { + arg_out = paddle::experimental::cast(arg_out_mp, x.dtype()); + } + } + + break; + } + default: + PD_THROW( + "function segment_coo_cuda_forward_kernel is not implemented for the index data type `", + phi::DataTypeToString(index.dtype()), "`"); + } + + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) { + out = paddle::experimental::cast(out_mp, x.dtype()); + } + + if (!init) { + data_t init_val = static_cast((reduce == "min") ? std::numeric_limits::max() : std::numeric_limits::lowest()); + post_process_kernel + <<>>( + init_val, out_numel, out.data() + ); + } + + })); + + return {out, arg_out}; +} + + +template +__global__ void +gather_coo_kernel(const mp_t *src_data, + const TensorInfo index_info, + data_t *out_data, size_t E, size_t N) { + + int row_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (row_idx < E) { + int offset = IndexToOffset::get( + row_idx, index_info); + int row = index_info.data[offset]; + + offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N; + mp_t val = __ldg(src_data + offset + row); + + out_data[row_idx] = static_cast(val); + } +} + +template +__global__ void gather_coo_broadcast_kernel( + const mp_t *src_data, + const TensorInfo index_info, + data_t *out_data, size_t E, size_t K, size_t N) { + + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / K; + int col_idx = thread_idx % K; + + if (thread_idx < E * K) { + int offset = IndexToOffset::get( + row_idx, index_info); + int row = index_info.data[offset]; + + offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N * K; + mp_t val = __ldg(src_data + offset + K * row + col_idx); + + out_data[thread_idx] = static_cast(val); + } +} + + +std::vector gather_coo_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + std::vector return_shape) { + CHECK_CUDA(index); + if (init) + CHECK_CUDA(init.get()); + + auto x_dims = x.shape(); + auto index_dims = index.shape(); + CHECK_INPUT(x_dims.size() >= index_dims.size()); + for (auto i = 0; i < index_dims.size() - 1; ++i) + CHECK_INPUT(x_dims[i] == index_dims[i]); + + auto dim = index_dims.size() - 1; + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + auto E = index.numel(); + auto K = out.numel() / E; + auto N = x_dims[dim]; + + PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_2_TYPES( + paddle::DataType::FLOAT16, paddle::DataType::BFLOAT16, + x.dtype(), "gather_coo_cuda_forward_kernel", ([&] { + using MPType = typename MPTypeTrait::Type; + paddle::Tensor x_mp; + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) + x_mp = paddle::experimental::cast(x, 
paddle::DataType::FLOAT32); + else + x_mp = x; + + switch(index.dtype()) { + case paddle::DataType::INT32: + { + auto index_info = getTensorInfo(index); + auto stride = index_info.strides[index_info.dims - 1]; + if (K == 1) + gather_coo_kernel + <<>>( + x_mp.data(), index_info, out.data(), E, N); + else + gather_coo_broadcast_kernel + <<>>( + x_mp.data(), index_info, out.data(), E, K, N); + break; + } + case paddle::DataType::INT64: + { + auto index_info = getTensorInfo(index); + auto stride = index_info.strides[index_info.dims - 1]; + if (K == 1) + gather_coo_kernel + <<>>( + x_mp.data(), index_info, out.data(), E, N); + else + gather_coo_broadcast_kernel + <<>>( + x_mp.data(), index_info, out.data(), E, K, N); + break; + } + default: + PD_THROW( + "function gather_coo_cuda_forward_kernel is not implemented for the index data type `", + phi::DataTypeToString(index.dtype()), "`"); + } + })); + + return {out}; +} \ No newline at end of file diff --git a/jointContribution/paddle_scatter/csrc/segment_csr.cc b/jointContribution/paddle_scatter/csrc/segment_csr.cc new file mode 100644 index 000000000..6a602185e --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/segment_csr.cc @@ -0,0 +1,409 @@ +#include "paddle/extension.h" + +#include +#include +#include + +#include "index_info.h" +#include "utils.h" + + +template +void segment_csr_cpu_forward_kernel(const data_t* x_data, + const std::vector& indptr_shape, + const TensorInfo& indptr_info, + const std::string& reduce, + int stride, + int64_t dim, + int N, + int K, + int E, + data_t* out_data, + index_t* arg_out_data) { + using MPType = typename MPTypeTrait::Type; + std::vector args(K); + std::vector vals(K); + index_t row_start, row_end; + for (auto n = 0; n < N; n++) { + auto offset = IndexPtrToOffset::get(n, indptr_info); + row_start = indptr_info.data[offset]; + row_end = indptr_info.data[offset + stride]; + + offset = (n / (indptr_shape[dim] - 1)) * E * K; + for (auto k = 0; k < K; k++) { + // init + if (reduce == "min") + vals[k] = static_cast(std::numeric_limits::max()); + else if (reduce == "max") + vals[k] = static_cast(std::numeric_limits::lowest()); + else if (reduce == "sum" || reduce == "mean") + vals[k] = static_cast(0); + } + + for (auto e = row_start; e < row_end; e++) { + for (auto k = 0; k < K; k++) { + // update + auto cmp = static_cast(x_data[offset + e * K + k]); + if ((reduce == "min" && cmp < vals[k]) || + (reduce == "max" && cmp > vals[k])) { + vals[k] = cmp; + args[k] = e; + } else if (reduce == "sum" || reduce == "mean") { + vals[k] += cmp; + } + } + } + + for (auto k = 0; k < K; k++) { + // write + auto idx = n * K + k; + auto count = row_end - row_start; + if (reduce == "sum") { + out_data[idx] = static_cast(vals[k]); + } else if (reduce == "mean") { + out_data[idx] = static_cast(vals[k] / static_cast(count > 0 ? 
count : 1)); + } else if (reduce == "min" || reduce == "max") { + if (count > 0) { + out_data[idx] = static_cast(vals[k]); + arg_out_data[idx] = args[k]; + } else { + out_data[idx] = static_cast(0); + } + } + } + } +} + +std::vector segment_csr_cpu_forward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::optional& init, + const std::vector& return_shape, + const std::string& reduce) { + CHECK_CPU(indptr); + if (init) + CHECK_CPU(init.get()); + + auto x_dims = x.shape(); + auto indptr_dims = indptr.shape(); + CHECK_INPUT(x_dims.size() >= indptr_dims.size()); + auto dim = indptr_dims.size() - 1; + + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + paddle::Tensor arg_out; + if (reduce == "min" || reduce == "max") { + arg_out = paddle::experimental::full_like(out, x_dims[dim], indptr.dtype(), indptr.place()); + } + + auto N = return_shape[dim] * (indptr.numel() / indptr_dims[dim]); + auto K = out.numel() / N; + auto E = x_dims[dim]; + + PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES( + x.dtype(), "segment_csr_cpu_forward_kernel", ([&] { + + switch(indptr.dtype()) { + case paddle::DataType::INT32: + { + auto indptr_info = getTensorInfo(indptr); + int stride = indptr_info.strides[indptr_info.dims - 1]; + segment_csr_cpu_forward_kernel( + x.data(), indptr_dims, indptr_info, reduce, + stride, dim, N, K, E, out.data(), + (reduce == "min" || reduce == "max") ? arg_out.data() : nullptr); + break; + } + case paddle::DataType::INT64: + { + auto indptr_info = getTensorInfo(indptr); + int stride = indptr_info.strides[indptr_info.dims - 1]; + segment_csr_cpu_forward_kernel( + x.data(), indptr_dims, indptr_info, reduce, + stride, dim, N, K, E, out.data(), + (reduce == "min" || reduce == "max") ? 
arg_out.data() : nullptr); + break; + } + default: + PD_THROW( + "function segment_csr_cpu_forward_kernel is not implemented for the indptr data type `", + phi::DataTypeToString(indptr.dtype()), "`"); + } + })); + + return {out, arg_out}; +} + +template +void gather_csr_cpu_forward_kernel(const data_t* x_data, + const TensorInfo& indptr_info, + const std::vector& indptr_shape, + int64_t dim, + int stride, + int N, + int K, + int E, + data_t* out_data) { + std::vector vals(K); + int64_t row_start, row_end; + for (auto n = 0; n < N; n++) { + auto offset = IndexPtrToOffset::get(n, indptr_info); + row_start = indptr_info.data[offset]; + row_end = indptr_info.data[offset + stride]; + + for (auto k = 0; k < K; k++) + vals[k] = x_data[n * K + k]; + + offset = (n / (indptr_shape[dim] - 1)) * E * K; + for (auto e = row_start; e < row_end; e++) + for (auto k = 0; k < K; k++) + out_data[offset + e * K + k] = vals[k]; + } +} + +std::vector gather_csr_cpu_forward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::optional& init, + const std::vector& return_shape) { + CHECK_CPU(indptr); + if (init) + CHECK_CPU(init.get()); + + auto x_dims = x.shape(); + auto indptr_dims = indptr.shape(); + CHECK_INPUT(x_dims.size() >= indptr_dims.size()); + auto dim = indptr_dims.size() - 1; + + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + auto N = x_dims[dim] * (indptr.numel() / indptr_dims[dim]); + auto K = x.numel() / N; + auto E = return_shape[dim]; + + PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES( + x.dtype(), "gather_csr_cpu_forward_kernel", ([&] { + + switch(indptr.dtype()) { + case paddle::DataType::INT32: + { + auto indptr_info = getTensorInfo(indptr); + int stride = indptr_info.strides[indptr_info.dims - 1]; + gather_csr_cpu_forward_kernel( + x.data(), indptr_info, indptr_dims, dim, + stride, N, K, E, out.data()); + break; + } + case paddle::DataType::INT64: + { + auto indptr_info = getTensorInfo(indptr); + int stride = indptr_info.strides[indptr_info.dims - 1]; + gather_csr_cpu_forward_kernel( + x.data(), indptr_info, indptr_dims, dim, + stride, N, K, E, out.data()); + break; + } + default: + PD_THROW( + "function gather_csr_cpu_forward_kernel is not implemented for the indptr data type `", + phi::DataTypeToString(indptr.dtype()), "`"); + } + })); + + return {out}; +} + +std::vector segment_csr_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::optional& init, + const std::vector& return_shape, + const std::string& reduce); + +std::vector SegmentCsrForward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::optional& init, + const std::vector& return_shape, + const std::string& reduce) { + if (x.is_cpu()) { + return segment_csr_cpu_forward(x, indptr, init, return_shape, reduce); + } else if (x.is_gpu()) { + return segment_csr_cuda_forward(x, indptr, init, return_shape, reduce); + } else { + PD_THROW("Unsupported device type for forward function of custom segment_csr operator."); + } +} + +std::vector gather_csr_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::optional& init, + const std::vector& return_shape); + +std::vector GatherCsrForward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const 
paddle::optional& init, + const std::vector& return_shape) { + if (x.is_cpu()) { + return gather_csr_cpu_forward(x, indptr, init, return_shape); + } else if (x.is_gpu()) { + return gather_csr_cuda_forward(x, indptr, init, return_shape); + } else { + PD_THROW("Unsupported device type for forward function of custom gather_csr operator."); + } +} + +std::vector SegmentCsrBackward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::optional& arg_out, + const paddle::Tensor& grad_out, + std::string reduce) { + if (!x.is_cpu() && !x.is_gpu() ) { + PD_THROW("Unsupported device type for backward function of custom segment_csr operator."); + } + if (reduce == "min" || reduce == "max") { + int64_t dim = indptr.shape().size() - 1; + auto x_shape = x.shape(); + x_shape[dim] += 1; + auto grad_x = paddle::zeros(x_shape, x.dtype(), x.place()); + paddle::experimental::put_along_axis_(grad_x, arg_out.get(), grad_out, dim); + grad_x = paddle::experimental::slice(grad_x, {dim}, {0}, {x_shape[dim] - 1}, {1}, {}); + return {grad_x}; + } else if (reduce == "mean") { + auto grad_x = paddle::empty(x.shape(), x.dtype(), x.place()); + int64_t dim = indptr.shape().size() - 1; + if (grad_x.numel() > 0) { + grad_x = GatherCsrForward(grad_out, indptr, paddle::optional(grad_x), grad_x.shape())[0]; + auto indptr1 = paddle::experimental::slice(indptr, {dim}, {0}, {indptr.shape()[dim] - 1}, {1}, {}); + auto indptr2 = paddle::experimental::slice(indptr, {dim}, {1}, {indptr.shape()[dim]}, {1}, {}); + auto count = paddle::experimental::cast(indptr2 - indptr1, grad_x.dtype()); + auto sizes = count.shape(); + sizes[dim] = grad_x.shape()[dim]; + // sizes[dim] = *indptr.flatten()[-1].data(); + count = GatherCsrForward(count, indptr, paddle::optional(paddle::none), sizes)[0]; + for (auto i = 0; i < grad_out.shape().size() - indptr.shape().size(); i++) + paddle::experimental::unsqueeze_(count, {-1}); + paddle::experimental::divide_(grad_x, count); + } + return {grad_x}; + } else if (reduce == "sum") { + auto grad_x = paddle::empty(x.shape(), x.dtype(), x.place()); + paddle::Tensor grad_in = GatherCsrForward(grad_out, indptr, paddle::optional(grad_x), grad_x.shape())[0]; + return {grad_in}; + } +} + +std::vector GatherCsrBackward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::Tensor& grad_out) { + if (!x.is_cpu() && !x.is_gpu() ) { + PD_THROW("Unsupported device type for backward function of custom gather_csr operator."); + } + auto x_shape = x.shape(); + auto grad_x = paddle::empty(x_shape, x.dtype(), x.place()); + paddle::Tensor grad_in = SegmentCsrForward(grad_out, indptr, paddle::optional(grad_x), grad_x.shape(), "sum")[0]; + return {grad_in}; +} + +std::vector> SegmentCsrFWInferShape(const std::vector& x_shape, + const std::vector& indptr_shape, + const paddle::optional>& init_shape, + std::vector return_shape, + std::string reduce) { + return {return_shape, return_shape}; +} + +std::vector SegmentCsrFWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& indptr_dtype, + const paddle::optional& init_dtype) { + return {x_dtype, indptr_dtype}; +} + +std::vector> SegmentCsrBWInferShape(const std::vector& x_shape, + const std::vector& indptr_shape, + const paddle::optional>& arg_out_shape, + const std::vector& grad_out_shape) { + return {x_shape}; +} + +std::vector SegmentCsrBWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& indptr_dtype, + const paddle::optional& arg_out_dtype, + const paddle::DataType& grad_out_dtype) { + return {x_dtype}; +} + 
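+// Note on the CSR ops registered at the end of this file: `Indptr` stores
+// row boundaries along its last dimension, and Out[..., n, ...] reduces
+// X[..., indptr[n]:indptr[n + 1], ...] with the chosen `reduce`. For
+// "min"/"max", ArgOut records the position within the reduced dimension
+// (empty rows yield 0). gather_csr expands the other way, repeating
+// X[..., n, ...] across the output range indptr[n]:indptr[n + 1].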
+std::vector> GatherCsrFWInferShape(const std::vector& x_shape, + const std::vector& indptr_shape, + const paddle::optional>& init_shape, + std::vector return_shape) { + return {return_shape}; +} + +std::vector GatherCsrFWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& indptr_dtype, + const paddle::optional& init_dtype) { + return {x_dtype}; +} + +std::vector> GatherCsrBWInferShape(const std::vector& x_shape, + const std::vector& indptr_shape, + const std::vector& grad_out_shape) { + return {x_shape}; +} + +std::vector GatherCsrBWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& indptr_dtype, + const paddle::DataType& grad_out_dtype) { + return {x_dtype}; +} + +PD_BUILD_OP(custom_segment_csr) + .Inputs({"X", "Indptr", paddle::Optional("Init")}) + .Outputs({"Out", paddle::Optional("ArgOut")}) + .Attrs({"return_shape: std::vector", + "reduce: std::string"}) + .SetKernelFn(PD_KERNEL(SegmentCsrForward)) + .SetInferShapeFn(PD_INFER_SHAPE(SegmentCsrFWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SegmentCsrFWInferDtype)); + +PD_BUILD_GRAD_OP(custom_segment_csr) + .Inputs({"X", "Indptr", paddle::Optional("ArgOut"), paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .Attrs({"reduce: std::string"}) + .SetKernelFn(PD_KERNEL(SegmentCsrBackward)) + .SetInferShapeFn(PD_INFER_SHAPE(SegmentCsrBWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SegmentCsrBWInferDtype)); + +PD_BUILD_OP(custom_gather_csr) + .Inputs({"X", "Indptr", paddle::Optional("Init")}) + .Outputs({"Out"}) + .Attrs({"return_shape: std::vector"}) + .SetKernelFn(PD_KERNEL(GatherCsrForward)) + .SetInferShapeFn(PD_INFER_SHAPE(GatherCsrFWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(GatherCsrFWInferDtype)); + +PD_BUILD_GRAD_OP(custom_gather_csr) + .Inputs({"X", "Indptr", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(GatherCsrBackward)) + .SetInferShapeFn(PD_INFER_SHAPE(GatherCsrBWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(GatherCsrBWInferDtype)); \ No newline at end of file diff --git a/jointContribution/paddle_scatter/csrc/segment_csr.cu b/jointContribution/paddle_scatter/csrc/segment_csr.cu new file mode 100644 index 000000000..7c5e8a127 --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/segment_csr.cu @@ -0,0 +1,382 @@ +#include "paddle/extension.h" + +#include +#include +#include + +#include "atomics.cuh" +#include "index_info.cuh" +#include "utils.cuh" + +#define THREADS 256 +#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS +#define FULL_MASK 0xffffffff + + +enum ReductionType { MIN, MAX, SUM, MEAN }; + +const std::map reduce2REDUCE = { + {"min", MIN}, {"max", MAX}, {"sum", SUM}, {"mean", MEAN} +}; + +template +__global__ void segment_csr_kernel(const data_t *x_data, + const TensorInfo indptr_info, + ReductionType reduce_type, + data_t *out_data, + index_t *arg_out_data, + size_t N, + size_t E) { + + // Each warp processes exactly `32/TB` rows and aggregates all row values + // via a parallel reduction. 
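+  // Lanes are grouped into teams of TB: each lane strides through the CSR
+  // range [indptr[row], indptr[row + 1]) with step TB and accumulates a
+  // partial result, a shuffle-down tree of width TB then combines the TB
+  // partials (carrying the arg position along for min/max), and lane 0 of
+  // the team writes the final value for its row.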
+ + using MPType = typename MPTypeTrait::Type; + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / TB; + int lane_idx = thread_idx & (TB - 1); + + if (row_idx < N) { + int offset = IndexPtrToOffset::get(row_idx, indptr_info); + index_t row_start = __ldg(indptr_info.data + offset); + index_t row_end = __ldg(indptr_info.data + offset + + indptr_info.strides[indptr_info.dims - 1]); + + // init + data_t val; + if (reduce_type == MIN) + val = static_cast(std::numeric_limits::max()); + else if (reduce_type == MAX) + val = static_cast(std::numeric_limits::lowest()); + else if (reduce_type == SUM || reduce_type == MEAN) + val = static_cast(0); + + index_t arg, arg_tmp; + offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E; + for (index_t x_idx = row_start + lane_idx; x_idx < row_end; + x_idx += TB) { + // update + auto cmp = x_data[offset + x_idx]; + if ((reduce_type == MIN && cmp < val) || + (reduce_type == MAX && cmp > val)) { + val = cmp; + arg = x_idx; + } else if (reduce_type == SUM || reduce_type == MEAN) { + val += cmp; + } + } + +#pragma unroll + for (int i = TB / 2; i > 0; i /= 2) { + // Parallel reduction inside a single warp. + if (reduce_type == MIN || reduce_type == MAX) + arg_tmp = SHFL_DOWN_SYNC(FULL_MASK, arg, i); + // update + MPType cmp = SHFL_DOWN_SYNC(FULL_MASK, static_cast(val), i); + if ((reduce_type == MIN && cmp < static_cast(val)) || + (reduce_type == MAX && cmp > static_cast(val))) { + val = static_cast(cmp); + arg = arg_tmp; + } else if (reduce_type == SUM || reduce_type == MEAN) { + val += static_cast(cmp); + } + } + + if (lane_idx == 0) { + // write + auto count = row_end - row_start; + if (reduce_type == SUM) { + out_data[row_idx] = val; + } else if (reduce_type == MEAN) { + out_data[row_idx] = val / static_cast(count > 0 ? count : 1); + } else if (reduce_type == MIN || reduce_type == MAX) { + if (count > 0) { + out_data[row_idx] = val; + arg_out_data[row_idx] = arg; + } else { + out_data[row_idx] = static_cast(0); + } + } + } + } +} + +template +__global__ void segment_csr_broadcast_min_max_kernel(const data_t *x_data, + const TensorInfo indptr_info, + ReductionType reduce_type, + data_t *out_data, + index_t *arg_out_data, + size_t N, + size_t K, + size_t E) { + + // Each thread processes exactly one row. It turned out that is more + // efficient than using shared memory due to avoiding synchronization + // barriers. 
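+  // One thread owns a single (row, inner-column) pair: it walks the CSR
+  // range [indptr[row], indptr[row + 1]) sequentially for its column, so no
+  // inter-thread reduction or atomics are needed. K is the product of the
+  // dimensions trailing the indexed one.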
+ + using MPType = typename MPTypeTrait::Type; + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / K; + int lane_idx = thread_idx % K; + + if (thread_idx < N * K) { + int offset = IndexPtrToOffset::get(row_idx, indptr_info); + index_t row_start = __ldg(indptr_info.data + offset); + index_t row_end = __ldg(indptr_info.data + offset + + indptr_info.strides[indptr_info.dims - 1]); + + // init + data_t val; + if (reduce_type == MIN) + val = static_cast(std::numeric_limits::max()); + else if (reduce_type == MAX) + val = static_cast(std::numeric_limits::lowest()); + else if (reduce_type == SUM || reduce_type == MEAN) + val = static_cast(0); + index_t arg; + + offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K; + for (index_t x_idx = row_start; x_idx < row_end; x_idx++) { + // update + auto cmp = x_data[offset + K * x_idx + lane_idx]; + if ((reduce_type == MIN && cmp < val) || + (reduce_type == MAX && cmp > val)) { + val = cmp; + arg = x_idx; + } else if (reduce_type == SUM || reduce_type == MEAN) { + val += cmp; + } + } + + // write + auto count = row_end - row_start; + if (reduce_type == SUM) { + out_data[thread_idx] = val; + } else if (reduce_type == MEAN) { + out_data[thread_idx] = val / static_cast(count > 0 ? count : 1); + } else if (reduce_type == MIN || reduce_type == MAX) { + if (count > 0) { + out_data[thread_idx] = val; + arg_out_data[thread_idx] = arg; + } else { + out_data[thread_idx] = static_cast(0); + } + } + } +} + +std::vector segment_csr_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::optional& init, + const std::vector& return_shape, + const std::string& reduce) { + CHECK_CUDA(indptr); + if (init) + CHECK_CUDA(init.get()); + + auto x_dims = x.shape(); + auto indptr_dims = indptr.shape(); + CHECK_INPUT(x_dims.size() >= indptr_dims.size()); + auto dim = indptr_dims.size() - 1; + + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + paddle::Tensor arg_out; + if (reduce == "min" || reduce == "max") { + arg_out = paddle::experimental::full_like(out, x_dims[dim], indptr.dtype(), indptr.place()); + } + + auto N = return_shape[dim] * (indptr.numel() / indptr_dims[dim]); + auto K = out.numel() / N; + auto E = x_dims[dim]; + + PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_2_TYPES( + paddle::DataType::FLOAT16, paddle::DataType::BFLOAT16, + x.dtype(), "segment_csr_cuda_forward_kernel", ([&] { + + const data_t* x_data = x.data(); + data_t* out_data = out.data(); + switch(indptr.dtype()) { + case paddle::DataType::INT32: + { + auto indptr_info = getTensorInfo(indptr); + int* arg_out_data = (reduce == "min" || reduce == "max") ? arg_out.data() : nullptr; + if (K == 1) { + segment_csr_kernel + <<>>( + x_data, indptr_info, reduce2REDUCE.at(reduce), + out_data, arg_out_data, N, E); + } else { + segment_csr_broadcast_min_max_kernel + <<>>( + x_data, indptr_info, reduce2REDUCE.at(reduce), + out_data, arg_out_data, N, K, E); + } + break; + } + case paddle::DataType::INT64: + { + auto indptr_info = getTensorInfo(indptr); + int64_t* arg_out_data = (reduce == "min" || reduce == "max") ? 
arg_out.data() : nullptr; + if (K == 1) { + segment_csr_kernel + <<>>( + x_data, indptr_info, reduce2REDUCE.at(reduce), + out_data, arg_out_data, N, E); + } else { + segment_csr_broadcast_min_max_kernel + <<>>( + x_data, indptr_info, reduce2REDUCE.at(reduce), + out_data, arg_out_data, N, K, E); + } + break; + } + default: + PD_THROW( + "function segment_csr_cuda_forward_kernel is not implemented for the indptr data type `", + phi::DataTypeToString(indptr.dtype()), "`"); + } + })); + + return {out, arg_out}; +} + +template +__global__ void +gather_csr_kernel(const mp_t *src_data, + const TensorInfo indptr_info, + data_t *out_data, size_t N, size_t E) { + + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / TB; + int lane_idx = thread_idx % TB; + + if (row_idx < N) { + int offset = IndexPtrToOffset::get(row_idx, indptr_info); + int row_start = __ldg(indptr_info.data + offset); + int row_end = __ldg(indptr_info.data + offset + + indptr_info.strides[indptr_info.dims - 1]); + mp_t val = __ldg(src_data + row_idx); + + offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E; + for (int out_idx = row_start + lane_idx; out_idx < row_end; out_idx += TB) { + out_data[offset + out_idx] = static_cast(val); // "Mostly" coalesced. + } + } +} + +template +__global__ void gather_csr_broadcast_kernel( + const mp_t *src_data, + const TensorInfo indptr_info, + data_t *out_data, size_t N, size_t K, size_t E) { + + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / K; + int lane_idx = thread_idx % K; + + if (thread_idx < N * K) { + int offset = IndexPtrToOffset::get(row_idx, indptr_info); + int row_start = __ldg(indptr_info.data + offset); + int row_end = __ldg(indptr_info.data + offset + + indptr_info.strides[indptr_info.dims - 1]); + + mp_t val = src_data[thread_idx]; // Coalesced. + + offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K; + for (int out_idx = row_start; out_idx < row_end; out_idx++) { + out_data[offset + K * out_idx + lane_idx] = static_cast(val); // "Mostly" coalesced. 
+ } + } +} + + +std::vector gather_csr_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::optional& init, + const std::vector& return_shape) { + CHECK_CUDA(indptr); + if (init) + CHECK_CUDA(init.get()); + + auto x_dims = x.shape(); + auto indptr_dims = indptr.shape(); + CHECK_INPUT(x_dims.size() >= indptr_dims.size()); + auto dim = indptr_dims.size() - 1; + + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + auto N = x_dims[dim] * (indptr.numel() / indptr_dims[dim]); + auto K = x.numel() / N; + auto E = return_shape[dim]; + + PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_2_TYPES( + paddle::DataType::FLOAT16, paddle::DataType::BFLOAT16, + x.dtype(), "gather_csr_cuda_forward_kernel", ([&] { + using MPType = typename MPTypeTrait::Type; + paddle::Tensor x_mp; + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) + x_mp = paddle::experimental::cast(x, paddle::DataType::FLOAT32); + else + x_mp = x; + const MPType* x_data = x_mp.data(); + data_t* out_data = out.data(); + switch(indptr.dtype()) { + case paddle::DataType::INT32: + { + auto indptr_info = getTensorInfo(indptr); + if (K == 1) + gather_csr_kernel + <<>>( + x_data, indptr_info, out_data, N, E); + else + gather_csr_broadcast_kernel + <<>>( + x_data, indptr_info, out_data, N, K, E); + break; + } + case paddle::DataType::INT64: + { + auto indptr_info = getTensorInfo(indptr); + if (K == 1) + gather_csr_kernel + <<>>( + x_data, indptr_info, out_data, N, E); + else + gather_csr_broadcast_kernel + <<>>( + x_data, indptr_info, out_data, N, K, E); + break; + } + default: + PD_THROW( + "function gather_csr_cuda_forward_kernel is not implemented for the indptr data type `", + phi::DataTypeToString(indptr.dtype()), "`"); + } + })); + + return {out}; +} \ No newline at end of file diff --git a/jointContribution/paddle_scatter/csrc/utils.cuh b/jointContribution/paddle_scatter/csrc/utils.cuh index d5759541b..f5371e043 100644 --- a/jointContribution/paddle_scatter/csrc/utils.cuh +++ b/jointContribution/paddle_scatter/csrc/utils.cuh @@ -85,4 +85,4 @@ class MPTypeTrait { #else #define SHFL_UP_SYNC __shfl_up_sync #define SHFL_DOWN_SYNC __shfl_down_sync -#endif \ No newline at end of file +#endif diff --git a/jointContribution/paddle_scatter/scatter.py b/jointContribution/paddle_scatter/scatter.py index c513ecf52..7d9b5d231 100644 --- a/jointContribution/paddle_scatter/scatter.py +++ b/jointContribution/paddle_scatter/scatter.py @@ -2,16 +2,7 @@ from typing import Tuple import paddle -from paddle import assign -from paddle import divide -from paddle import floor_divide -from paddle import full -from paddle import full_like -from paddle import ones -from paddle import put_along_axis -from paddle import where -from paddle import zeros -from paddle_scatter_min_max_ops import custom_scatter_min_max +import paddle_scatter_ops from .utils import broadcast @@ -51,15 +42,19 @@ def scatter_sum( size[dim] = 0 else: size[dim] = int(index.max()) + 1 - arr = zeros(size, dtype=src.dtype) + arr = paddle.zeros(size, dtype=src.dtype) if src.numel() == 0: return arr - return put_along_axis(arr, indices=index, values=src, axis=dim, reduce="add") + return paddle.put_along_axis( + arr, indices=index, values=src, axis=dim, reduce="add" + ) 
else: if src.numel() == 0: return out - result = put_along_axis(out, indices=index, values=src, axis=dim, reduce="add") - assign(result, out) + result = paddle.put_along_axis( + out, indices=index, values=src, axis=dim, reduce="add" + ) + paddle.assign(result, out) return out @@ -127,15 +122,19 @@ def scatter_mul( size[dim] = 0 else: size[dim] = int(index.max()) + 1 - arr = ones(size, dtype=src.dtype) + arr = paddle.ones(size, dtype=src.dtype) if src.numel() == 0: return arr - return put_along_axis(arr, indices=index, values=src, axis=dim, reduce="mul") + return paddle.put_along_axis( + arr, indices=index, values=src, axis=dim, reduce="mul" + ) else: if src.numel() == 0: return out - result = put_along_axis(out, indices=index, values=src, axis=dim, reduce="mul") - assign(result, out) + result = paddle.put_along_axis( + out, indices=index, values=src, axis=dim, reduce="mul" + ) + paddle.assign(result, out) return out @@ -174,18 +173,18 @@ def scatter_mean( if index.dim() <= index_dim: index_dim = index.dim() - 1 - ones_tensor = ones(index.shape, dtype=src.dtype) + ones_tensor = paddle.ones(index.shape, dtype=src.dtype) tmp = scatter_sum(ones_tensor, index, index_dim, None, dim_size) - count = where(tmp < 1, full_like(tmp, 1), tmp, name="where") + count = paddle.where(tmp < 1, paddle.full_like(tmp, 1), tmp, name="where") count = broadcast(count, sums, dim) if sums.is_floating_point(): - result = divide(sums, count) + result = paddle.divide(sums, count) else: - result = floor_divide(sums, count) + result = paddle.floor_divide(sums, count) if out is None: return result else: - assign(result, out) + paddle.assign(result, out) return out @@ -230,20 +229,22 @@ def scatter_min( if out is None: if src.numel() == 0: return ( - zeros(size, dtype=src.dtype), - full(size, src.shape[dim], index.dtype), + paddle.zeros(size, dtype=src.dtype), + paddle.full(size, src.shape[dim], index.dtype), ) - return custom_scatter_min_max(src, index, None, size, "min", dim) + return paddle_scatter_ops.custom_scatter_min_max( + src, index, None, size, "min", dim + ) else: if src.numel() == 0: - return (out, full(size, src.shape[dim], index.dtype)) + return (out, paddle.full(size, src.shape[dim], index.dtype)) for i in range(len(size)): if i != dim: assert size[i] == out.shape[i] - result, arg_result = custom_scatter_min_max( + result, arg_result = paddle_scatter_ops.custom_scatter_min_max( src, index, out, out.shape, "min", dim ) - assign(result, out) + paddle.assign(result, out) return out, arg_result @@ -288,20 +289,22 @@ def scatter_max( if out is None: if src.numel() == 0: return ( - zeros(size, dtype=src.dtype), - full(size, src.shape[dim], index.dtype), + paddle.zeros(size, dtype=src.dtype), + paddle.full(size, src.shape[dim], index.dtype), ) - return custom_scatter_min_max(src, index, None, size, "max", dim) + return paddle_scatter_ops.custom_scatter_min_max( + src, index, None, size, "max", dim + ) else: if src.numel() == 0: - return (out, full(size, src.shape[dim], index.dtype)) + return (out, paddle.full(size, src.shape[dim], index.dtype)) for i in range(len(size)): if i != dim: assert size[i] == out.shape[i] - result, arg_result = custom_scatter_min_max( + result, arg_result = paddle_scatter_ops.custom_scatter_min_max( src, index, out, out.shape, "max", dim ) - assign(result, out) + paddle.assign(result, out) return out, arg_result diff --git a/jointContribution/paddle_scatter/segment_coo.py b/jointContribution/paddle_scatter/segment_coo.py index 75af63aff..91ea2e5aa 100644 --- 
a/jointContribution/paddle_scatter/segment_coo.py +++ b/jointContribution/paddle_scatter/segment_coo.py @@ -2,19 +2,7 @@ from typing import Tuple import paddle -from paddle import assign -from paddle import full -from paddle import slice -from paddle import take_along_axis -from paddle import to_tensor -from paddle import zeros -from paddle.geometric import segment_mean -from paddle.geometric import segment_sum -from paddle.nn.functional import pad -from paddle_scatter_min_max_ops import custom_segment_coo_min_max - -from .utils import broadcast -from .utils import transform_3d +import paddle_scatter_ops def segment_sum_coo( @@ -41,46 +29,42 @@ def segment_sum_coo( Returns: paddle.Tensor, the reduced tensor by sum reduction method. """ + src_shape = src.shape index_shape = index.shape dim = len(index_shape) - 1 - index = broadcast(index, src, dim) + # broadcast indptr to src + index_shape[:dim] = src_shape[:dim] + if src.numel() == 0: + index = index.reshape(index_shape) + else: + index = index.expand(index_shape) + size = src.shape + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + tmp = index.index_select( + index=paddle.to_tensor([index.shape[dim] - 1]), axis=dim + ).squeeze(dim) + tmp = tmp.max() if tmp.numel() > 1 else tmp + size[dim] = int(tmp) + 1 + if out is None: - out_size = src.shape - if dim_size is not None: - out_size[dim] = dim_size - elif index.numel() == 0: - out_size[dim] = 0 - else: - tmp = index.index_select( - index=to_tensor([index.shape[dim] - 1]), axis=dim - ).squeeze(dim) - tmp = tmp.max() if tmp.numel() > 1 else tmp - out_size[dim] = int(tmp) + 1 if src.numel() == 0: - return zeros(out_size, dtype=src.dtype) + return paddle.zeros(size, dtype=src.dtype) + return paddle_scatter_ops.custom_segment_coo(src, index, None, size, "sum")[0] else: if src.numel() == 0: return out - out_size = out.shape - init = out.clone() - tmp = zeros(out_size, dtype=src.dtype) - src_fatten = transform_3d(src, dim) - index_flatten = transform_3d(index, dim) - out_flatten = transform_3d(tmp, dim) - for i in range(src_fatten.shape[0]): - for j in range(src_fatten.shape[-1]): - result = segment_sum(src_fatten[i, :, j], index_flatten[i, :, j]) - if out_size[dim] >= len(result): - out_flatten[i, : len(result), j] = result - else: - out_flatten[i, :, j] = result[: out_size[dim]] - res = out_flatten.reshape(out_size) - if out is None: - return res - else: - res = res + init - assign(res, out) - return out + for i in range(len(size)): + if i != dim: + assert size[i] == out.shape[i] + result = paddle_scatter_ops.custom_segment_coo( + src, index, out, out.shape, "sum" + )[0] + paddle.assign(result, out) + return out def segment_add_coo( @@ -134,38 +118,42 @@ def segment_mean_coo( Returns: paddle.Tensor, the reduced tensor by mean reduction method. 
""" + src_shape = src.shape index_shape = index.shape dim = len(index_shape) - 1 - index = broadcast(index, src, dim) + # broadcast indptr to src + index_shape[:dim] = src_shape[:dim] + if src.numel() == 0: + index = index.reshape(index_shape) + else: + index = index.expand(index_shape) + size = src.shape + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + tmp = index.index_select( + index=paddle.to_tensor([index.shape[dim] - 1]), axis=dim + ).squeeze(dim) + tmp = tmp.max() if tmp.numel() > 1 else tmp + size[dim] = int(tmp) + 1 + if out is None: - out_size = src.shape - if dim_size is not None: - out_size[dim] = dim_size - elif index.numel() == 0: - out_size[dim] = 0 - else: - tmp = index.index_select( - index=to_tensor([index.shape[dim] - 1]), axis=dim - ).squeeze(dim) - tmp = tmp.max() if tmp.numel() > 1 else tmp - out_size[dim] = int(tmp) + 1 - out = zeros(out_size, dtype=src.dtype) + if src.numel() == 0: + return paddle.zeros(size, dtype=src.dtype) + return paddle_scatter_ops.custom_segment_coo(src, index, None, size, "mean")[0] else: - out_size = out.shape - if src.numel() == 0: + if src.numel() == 0: + return out + for i in range(len(size)): + if i != dim: + assert size[i] == out.shape[i] + result = paddle_scatter_ops.custom_segment_coo( + src, index, out, out.shape, "mean" + )[0] + paddle.assign(result, out) return out - src_fatten = transform_3d(src, dim) - index_flatten = transform_3d(index, dim) - out_flatten = transform_3d(out, dim) - for i in range(src_fatten.shape[0]): - for j in range(src_fatten.shape[-1]): - result = segment_mean(src_fatten[i, :, j], index_flatten[i, :, j]) - if out_size[dim] >= len(result): - out_flatten[i, : len(result), j] = result - else: - out_flatten[i, :, j] = result[: out_size[dim]] - assign(out_flatten.reshape(out_size), out) - return out def segment_min_coo( @@ -208,7 +196,7 @@ def segment_min_coo( size[dim] = 0 else: tmp = index.index_select( - index=to_tensor([index.shape[dim] - 1]), axis=dim + index=paddle.to_tensor([index.shape[dim] - 1]), axis=dim ).squeeze(dim) tmp = tmp.max() if tmp.numel() > 1 else tmp size[dim] = int(tmp) + 1 @@ -216,20 +204,20 @@ def segment_min_coo( if out is None: if src.numel() == 0: return ( - zeros(size, dtype=src.dtype), - full(size, src.shape[dim], index.dtype), + paddle.zeros(size, dtype=src.dtype), + paddle.full(size, src.shape[dim], index.dtype), ) - return custom_segment_coo_min_max(src, index, None, size, "min") + return paddle_scatter_ops.custom_segment_coo(src, index, None, size, "min") else: if src.numel() == 0: - return (out, full(size, src.shape[dim], index.dtype)) + return (out, paddle.full(size, src.shape[dim], index.dtype)) for i in range(len(size)): if i != dim: assert size[i] == out.shape[i] - result, arg_result = custom_segment_coo_min_max( + result, arg_result = paddle_scatter_ops.custom_segment_coo( src, index, out, out.shape, "min" ) - assign(result, out) + paddle.assign(result, out) return out, arg_result @@ -273,7 +261,7 @@ def segment_max_coo( size[dim] = 0 else: tmp = index.index_select( - index=to_tensor([index.shape[dim] - 1]), axis=dim + index=paddle.to_tensor([index.shape[dim] - 1]), axis=dim ).squeeze(dim) tmp = tmp.max() if tmp.numel() > 1 else tmp size[dim] = int(tmp) + 1 @@ -281,20 +269,20 @@ def segment_max_coo( if out is None: if src.numel() == 0: return ( - zeros(size, dtype=src.dtype), - full(size, src.shape[dim], index.dtype), + paddle.zeros(size, dtype=src.dtype), + paddle.full(size, src.shape[dim], index.dtype), ) - return 
custom_segment_coo_min_max(src, index, None, size, "max") + return paddle_scatter_ops.custom_segment_coo(src, index, None, size, "max") else: if src.numel() == 0: - return (out, full(size, src.shape[dim], index.dtype)) + return (out, paddle.full(size, src.shape[dim], index.dtype)) for i in range(len(size)): if i != dim: assert size[i] == out.shape[i] - result, arg_result = custom_segment_coo_min_max( + result, arg_result = paddle_scatter_ops.custom_segment_coo( src, index, out, out.shape, "max" ) - assign(result, out) + paddle.assign(result, out) return out, arg_result @@ -429,33 +417,20 @@ def gather_coo( dim = len(index_shape) - 1 src_shape = src.shape - # broadcast index's dimension to the same as src's - for _ in range(len(src_shape) - len(index_shape)): - index = index.unsqueeze(-1) - new_index_shape = src_shape.copy() - new_index_shape[dim] = index_shape[dim] - if src.numel() == 0: - index = index.reshape(new_index_shape) - else: - index = index.expand(new_index_shape) - if out is None: out_size = src_shape if index.numel() == 0: out_size[dim] = 0 - else: - out_size[dim] = index.shape[dim] - out = zeros(out_size, dtype=src.dtype) + return paddle.zeros(out_size, dtype=src.dtype) + out_size[dim] = index.shape[dim] + return paddle_scatter_ops.custom_gather_coo(src, index, None, out_size) else: + if src.numel() == 0: + return out out_size = out.shape - if src.numel() == 0: + for i in range(len(out_size)): + if i != dim: + assert src_shape[i] == out_size[i] + result = paddle_scatter_ops.custom_gather_coo(src, index, out, out_size) + paddle.assign(result, out) return out - - result = take_along_axis(src, index, dim, broadcast=False) - if out_size[dim] > result.shape[dim]: - padding = [0] * 2 * len(out_size) - padding[2 * dim + 1] = out_size[dim] - result.shape[dim] - result = pad(result, padding, mode="constant", value=0) - elif out_size[dim] < result.shape[dim]: - result = slice(result, [dim], 0, out_size[dim]) - return assign(result, out) diff --git a/jointContribution/paddle_scatter/segment_csr.py b/jointContribution/paddle_scatter/segment_csr.py index 6ecfbeb92..8d0e7eec8 100644 --- a/jointContribution/paddle_scatter/segment_csr.py +++ b/jointContribution/paddle_scatter/segment_csr.py @@ -2,17 +2,7 @@ from typing import Tuple import paddle -from paddle import arange -from paddle import assign -from paddle import full -from paddle import repeat_interleave -from paddle import zeros -from paddle.geometric import segment_mean -from paddle.geometric import segment_sum -from paddle_scatter_min_max_ops import custom_segment_csr_min_max - -from .utils import transform_2d -from .utils import transform_3d +import paddle_scatter_ops def segment_sum_csr( @@ -48,46 +38,19 @@ def segment_sum_csr( else: indptr = indptr.expand(indptr_shape) - num_seg = indptr_shape[dim] - 1 if out is None: - out_size = src_shape - if indptr.numel() == 0: - out_size[dim] = 0 - else: - out_size[dim] = num_seg + size = src.shape + size[dim] = max(indptr_shape[dim] - 1, 0) if src.numel() == 0: - return zeros(out_size, dtype=src.dtype) + return paddle.zeros(size, dtype=src.dtype) + return paddle_scatter_ops.custom_segment_csr(src, indptr, None, size, "sum")[0] else: - assert ( - out.shape[dim] == num_seg - ), "The (size of indptr at last dimension) must be\ - equal to the (size of out at the same dimension) + 1." 
if src.numel() == 0: return out - out_size = out.shape - tmp = zeros(out_size, dtype=src.dtype) - - repeats = indptr.diff(n=1, axis=dim) - assert ( - repeats.sum(axis=dim) == src.shape[dim] - ).all(), "The length of specified index by indptr shoud be\ - equal to the size of src at last dimension of indptr." - src_flatten = transform_3d(src, dim) - out_flatten = transform_3d(tmp, dim) - repeats_flatten = transform_2d(repeats, dim) - src_dim_indices = arange(num_seg) - for i in range(src_flatten.shape[0]): - for j in range(src_flatten.shape[-1]): - belongs_to = repeat_interleave(src_dim_indices, repeats_flatten[i], 0) - result = segment_sum(src_flatten[i, :, j], belongs_to) - if out_size[dim] >= len(result): - out_flatten[i, : len(result), j] = result - else: - out_flatten[i, :, j] = result[: out_size[dim]] - if out is None: - return out_flatten.reshape(out_size) - else: - assign(out_flatten.reshape(out_size), out) + result = paddle_scatter_ops.custom_segment_csr( + src, indptr, out, out.shape, "sum" + )[0] + paddle.assign(result, out) return out @@ -150,46 +113,19 @@ def segment_mean_csr( else: indptr = indptr.expand(indptr_shape) - num_seg = indptr_shape[dim] - 1 if out is None: - out_size = src_shape - if indptr.numel() == 0: - out_size[dim] = 0 - else: - out_size[dim] = num_seg + size = src.shape + size[dim] = max(indptr_shape[dim] - 1, 0) if src.numel() == 0: - return zeros(out_size, dtype=src.dtype) + return paddle.zeros(size, dtype=src.dtype) + return paddle_scatter_ops.custom_segment_csr(src, indptr, None, size, "mean")[0] else: - assert ( - out.shape[dim] == num_seg - ), "The (size of indptr at last dimension) must be\ - equal to the (size of out at the same dimension) + 1." if src.numel() == 0: return out - out_size = out.shape - tmp = zeros(out_size, dtype=src.dtype) - - repeats = indptr.diff(n=1, axis=dim) - assert ( - repeats.sum(axis=dim) == src.shape[dim] - ).all(), "The length of specified index by indptr shoud be\ - equal to the size of src at last dimension of indptr." 
- src_flatten = transform_3d(src, dim) - out_flatten = transform_3d(tmp, dim) - repeats_flatten = transform_2d(repeats, dim) - src_dim_indices = arange(num_seg) - for i in range(src_flatten.shape[0]): - for j in range(src_flatten.shape[-1]): - belongs_to = repeat_interleave(src_dim_indices, repeats_flatten[i], 0) - result = segment_mean(src_flatten[i, :, j], belongs_to) - if out_size[dim] >= len(result): - out_flatten[i, : len(result), j] = result - else: - out_flatten[i, :, j] = result[: out_size[dim]] - if out is None: - return out_flatten.reshape(out_size) - else: - assign(out_flatten.reshape(out_size), out) + result = paddle_scatter_ops.custom_segment_csr( + src, indptr, out, out.shape, "mean" + )[0] + paddle.assign(result, out) return out @@ -231,15 +167,17 @@ def segment_min_csr( size[dim] = max(indptr_shape[dim] - 1, 0) if src.numel() == 0: return ( - zeros(size, dtype=src.dtype), - full(size, src.shape[dim], indptr.dtype), + paddle.zeros(size, dtype=src.dtype), + paddle.full(size, src.shape[dim], indptr.dtype), ) - return custom_segment_csr_min_max(src, indptr, size, "min") + return paddle_scatter_ops.custom_segment_csr(src, indptr, None, size, "min") else: if src.numel() == 0: - return (out, full(size, src.shape[dim], indptr.dtype)) - result, arg_result = custom_segment_csr_min_max(src, indptr, out.shape, "min") - assign(result, out) + return (out, paddle.full(size, src.shape[dim], indptr.dtype)) + result, arg_result = paddle_scatter_ops.custom_segment_csr( + src, indptr, out, out.shape, "min" + ) + paddle.assign(result, out) return out, arg_result @@ -276,20 +214,25 @@ def segment_max_csr( else: indptr = indptr.expand(indptr_shape) + size = src.shape if out is None: - size = src.shape size[dim] = max(indptr_shape[dim] - 1, 0) if src.numel() == 0: return ( - zeros(size, dtype=src.dtype), - full(size, src.shape[dim], indptr.dtype), + paddle.zeros(size, dtype=src.dtype), + paddle.full(size, src.shape[dim], indptr.dtype), ) - return custom_segment_csr_min_max(src, indptr, size, "max") + return paddle_scatter_ops.custom_segment_csr(src, indptr, None, size, "max") else: if src.numel() == 0: - return (out, full(size, src.shape[dim], indptr.dtype)) - result, arg_result = custom_segment_csr_min_max(src, indptr, out.shape, "max") - assign(result, out) + return (out, paddle.full(size, src.shape[dim], indptr.dtype)) + for i in range(len(size)): + if i != dim: + assert size[i] == out.shape[i] + result, arg_result = paddle_scatter_ops.custom_segment_csr( + src, indptr, out, out.shape, "max" + ) + paddle.assign(result, out) return out, arg_result @@ -429,26 +372,16 @@ def gather_csr( out_size = src_shape if indptr.numel() == 0: out_size[dim] = 0 - else: - # refer to the original design in source cpp code - out_size[dim] = indptr.flatten()[-1] - out = zeros(out_size, dtype=src.dtype) + return paddle.zeros(out_size, dtype=src.dtype) + out_size[dim] = indptr.flatten()[-1] + return paddle_scatter_ops.custom_gather_csr(src, indptr, None, out_size) else: + if src.numel() == 0: + return out out_size = out.shape - if src.numel() == 0: + for i in range(len(out_size)): + if i != dim: + assert src_shape[i] == out_size[i] + result = paddle_scatter_ops.custom_gather_csr(src, indptr, out, out_size) + paddle.assign(result, out) return out - - repeats = indptr.diff(n=1, axis=dim) - src_flatten = transform_3d(src, dim) - out_flatten = transform_3d(out, dim) - repeats_flatten = transform_2d(repeats, dim) - for i in range(src_flatten.shape[0]): - for j in range(src_flatten.shape[-1]): - result = 
repeat_interleave(src_flatten[i, :, j], repeats_flatten[i], 0) - repeat_sum = repeats_flatten[i].sum() - if out_size[dim] >= repeat_sum: - out_flatten[i, :repeat_sum, j] = result[:repeat_sum] - else: - out_flatten[i, :, j] = result[: out_size[dim]] - assign(out_flatten.reshape(out_size), out) - return out diff --git a/jointContribution/paddle_scatter/setup.py b/jointContribution/paddle_scatter/setup.py index d78a28133..908c0b44e 100644 --- a/jointContribution/paddle_scatter/setup.py +++ b/jointContribution/paddle_scatter/setup.py @@ -32,10 +32,10 @@ def get_extensions(): setup( - name="paddle_scatter_min_max_ops", + name="paddle_scatter_ops", version="1.0", author="NKNaN", url="https://github.com/PaddlePaddle/PaddleScience/jointContribution/paddle_scatter", - description="Paddle extension of scatter and segment operators with min and max reduction methods", + description="Paddle extension of scatter and segment operators with min and max reduction methods, originally from https://github.com/rusty1s/pytorch_scatter", ext_modules=get_extensions(), ) diff --git a/jointContribution/paddle_scatter/testing.py b/jointContribution/paddle_scatter/testing.py index dfe23294c..1126fb8a6 100644 --- a/jointContribution/paddle_scatter/testing.py +++ b/jointContribution/paddle_scatter/testing.py @@ -4,13 +4,8 @@ reductions = ["sum", "add", "mean", "min", "max"] -dtypes = [ - # paddle.float16, paddle.bfloat16, - paddle.float32, - paddle.float64, - paddle.int32, - paddle.int64, -] +dtypes = [paddle.float32, paddle.float64, paddle.int32, paddle.int64] +dtypes_half = [paddle.float16, paddle.bfloat16] ind_dtypes = [paddle.int32, paddle.int64] grad_dtypes = [paddle.float32, paddle.float64] diff --git a/jointContribution/paddle_scatter/tests/composite/test_softmax.py b/jointContribution/paddle_scatter/tests/composite/test_softmax.py index 20c2c6bb8..28ae9410c 100644 --- a/jointContribution/paddle_scatter/tests/composite/test_softmax.py +++ b/jointContribution/paddle_scatter/tests/composite/test_softmax.py @@ -39,7 +39,7 @@ def test_log_softmax(): paddle.to_tensor([7], dtype=paddle.float32), axis=-1 ) out4 = paddle.nn.functional.log_softmax( - paddle.to_tensor([-1, float("-inf")]), axis=-1 + paddle.to_tensor([-1.0, float("-inf")]), axis=-1 ) expected = paddle.stack( diff --git a/jointContribution/paddle_scatter/tests/test_gather.py b/jointContribution/paddle_scatter/tests/test_gather.py index fbbb4df42..6fd8e40a3 100644 --- a/jointContribution/paddle_scatter/tests/test_gather.py +++ b/jointContribution/paddle_scatter/tests/test_gather.py @@ -5,6 +5,7 @@ from paddle_scatter import gather_coo from paddle_scatter import gather_csr from paddle_scatter.testing import dtypes +from paddle_scatter.testing import dtypes_half from paddle_scatter.testing import ind_dtypes from paddle_scatter.testing import places from paddle_scatter.testing import tensor @@ -72,6 +73,29 @@ def test_forward(test, dtype, ind_dtype, place): assert paddle.all(out == expected) +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize( + "test,dtype,ind_dtype", product(tests, dtypes_half, ind_dtypes) +) +def test_forward_half(test, dtype, ind_dtype): + paddle.set_device("gpu") + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + indptr = tensor(test["indptr"], ind_dtype) + expected = tensor(test["expected"], dtype) + + out = gather_csr(src, indptr) + assert 
paddle.all(out == expected) + + out = gather_coo(src, index) + assert paddle.all(out == expected) + + @pytest.mark.parametrize("test,place", product(tests, places)) def test_backward(test, place): paddle.set_device(place) @@ -92,6 +116,32 @@ def test_backward(test, place): assert paddle.all(src.grad == exp_grad) +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize("test", tests) +def test_backward_half(test): + paddle.set_device("gpu") + index = tensor(test["index"], paddle.int64) + indptr = tensor(test["indptr"], paddle.int64) + exp_grad = tensor(test["expected_grad"], paddle.float16) + + src = tensor(test["src"], paddle.float16) + src.stop_gradient = False + out = gather_csr(src, indptr) + out.backward() + assert paddle.all(src.grad == exp_grad) + + src = tensor(test["src"], paddle.float16) + src.stop_gradient = False + out = gather_coo(src, index) + out.backward() + assert paddle.all(src.grad == exp_grad) + + @pytest.mark.parametrize( "test,dtype,ind_dtype,place", product(tests, dtypes, ind_dtypes, places) ) @@ -115,6 +165,35 @@ def test_out(test, dtype, ind_dtype, place): assert paddle.all(out == expected) +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize( + "test,dtype,ind_dtype", product(tests, dtypes_half, ind_dtypes) +) +def test_out_half(test, dtype, ind_dtype): + paddle.set_device("gpu") + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + indptr = tensor(test["indptr"], ind_dtype) + expected = tensor(test["expected"], dtype) + + size = src.shape + size[index.dim() - 1] = index.shape[-1] + out = paddle.full(size, -2).astype(dtype) + + gather_csr(src, indptr, out) + assert paddle.all(out == expected) + + out.fill_(-2) + + gather_coo(src, index, out) + assert paddle.all(out == expected) + + @pytest.mark.parametrize( "test,dtype,ind_dtype,place", product(tests, dtypes, ind_dtypes, places) ) @@ -143,3 +222,39 @@ def test_non_contiguous(test, dtype, ind_dtype, place): out = gather_coo(src, index) assert paddle.all(out == expected) + + +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize( + "test,dtype,ind_dtype", product(tests, dtypes_half, ind_dtypes) +) +def test_non_contiguous_half(test, dtype, ind_dtype): + paddle.set_device("gpu") + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + indptr = tensor(test["indptr"], ind_dtype) + expected = tensor(test["expected"], dtype) + + if src.dim() > 1: + shape = list(range(src.dim())) + shape[0], shape[1] = shape[1], shape[0] + src = src.transpose(shape).contiguous().transpose(shape) + if index.dim() > 1: + shape = list(range(index.dim())) + shape[0], shape[1] = shape[1], shape[0] + index = index.transpose(shape).contiguous().transpose(shape) + if indptr.dim() > 1: + shape = list(range(indptr.dim())) + shape[0], shape[1] = shape[1], shape[0] + indptr = indptr.transpose(shape).contiguous().transpose(shape) + + out = gather_csr(src, indptr) + assert paddle.all(out == expected) + + out = gather_coo(src, index) + assert paddle.all(out == expected) diff --git a/jointContribution/paddle_scatter/tests/test_multi_gpu.py 
b/jointContribution/paddle_scatter/tests/test_multi_gpu.py index e8dd74b6b..c5136d705 100644 --- a/jointContribution/paddle_scatter/tests/test_multi_gpu.py +++ b/jointContribution/paddle_scatter/tests/test_multi_gpu.py @@ -22,8 +22,8 @@ ] -@pytest.mark.skipif(not paddle.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(paddle.cuda.device_count() < 2, reason="No multiple GPUS") +@pytest.mark.skipif(paddle.device.cuda.device_count() == 0, reason="CUDA not available") +@pytest.mark.skipif(paddle.device.cuda.device_count() < 2, reason="No multiple GPUS") @pytest.mark.parametrize("test,reduce,dtype", product(tests, reductions, dtypes)) def test_forward(test, reduce, dtype): paddle.set_device("gpu:1") diff --git a/jointContribution/paddle_scatter/tests/test_scatter.py b/jointContribution/paddle_scatter/tests/test_scatter.py index 51c519222..f15fe3894 100644 --- a/jointContribution/paddle_scatter/tests/test_scatter.py +++ b/jointContribution/paddle_scatter/tests/test_scatter.py @@ -5,6 +5,7 @@ import paddle_scatter import pytest from paddle_scatter.testing import dtypes +from paddle_scatter.testing import dtypes_half from paddle_scatter.testing import ind_dtypes from paddle_scatter.testing import places from paddle_scatter.testing import reductions @@ -221,6 +222,33 @@ def test_forward(test, reduce, dtype, ind_dtype, place): assert paddle.all(out == expected) +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize( + "test,reduce,dtype,ind_dtype", product(tests, reductions, dtypes_half, ind_dtypes) +) +def test_forward_half(test, reduce, dtype, ind_dtype): + paddle.set_device("gpu") + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + dim = test["dim"] + expected = tensor(test[reduce], dtype) + + fn = getattr(paddle_scatter, "scatter_" + reduce) + out = fn(src, index, dim) + if reduce == "min" or reduce == "max": + out, arg_out = out + arg_expected = tensor(test["arg_" + reduce], ind_dtype) + print(arg_out) + print(arg_expected) + assert paddle.all(arg_out == arg_expected) + assert paddle.all(out == expected) + + @pytest.mark.parametrize("test,reduce,place", product(tests, reductions, places)) def test_backward(test, reduce, place): paddle.set_device(place) @@ -237,6 +265,28 @@ def test_backward(test, reduce, place): ) +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize("test,reduce", product(tests, reductions)) +def test_backward_half(test, reduce): + paddle.set_device("gpu") + index = tensor(test["index"], paddle.int64) + dim = test["dim"] + exp_grad = tensor(test[f"{reduce}_grad"], paddle.float16) + + src = tensor(test["src"], paddle.float16) + src.stop_gradient = False + out = paddle_scatter.scatter(src, index, dim, None, None, reduce) + out.backward() + np.testing.assert_allclose( + src.grad.numpy(), exp_grad.numpy(), rtol=1e-05, atol=1e-06 + ) + + @pytest.mark.parametrize( "test,reduce,dtype,ind_dtype,place", product(tests, reductions, dtypes, ind_dtypes, places), @@ -268,11 +318,81 @@ def test_out(test, reduce, dtype, ind_dtype, place): np.testing.assert_allclose(out.numpy(), expected.numpy(), rtol=1e-05, atol=1e-06) +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and 
paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize( + "test,reduce,dtype,ind_dtype", product(tests, reductions, dtypes_half, ind_dtypes) +) +def test_out_half(test, reduce, dtype, ind_dtype): + paddle.set_device("gpu") + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + dim = test["dim"] + expected = tensor(test[reduce], dtype) + + out = paddle.full_like(expected, -2) + + getattr(paddle_scatter, "scatter_" + reduce)(src, index, dim, out) + + if reduce == "sum" or reduce == "add": + expected = expected - 2 + elif reduce == "mul": + expected = out # We can not really test this here. + elif reduce == "mean": + expected = out # We can not really test this here. + elif reduce == "min": + print(expected) + expected = expected.fill_(-2) + elif reduce == "max": + expected[expected == 0] = -2 + else: + raise ValueError + + np.testing.assert_allclose(out.numpy(), expected.numpy(), rtol=1e-05, atol=1e-06) + + +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize( + "test,reduce,dtype,ind_dtype", product(tests, reductions, dtypes_half, ind_dtypes) +) +def test_non_contiguous(test, reduce, dtype, ind_dtype): + paddle.set_device("gpu") + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + dim = test["dim"] + expected = tensor(test[reduce], dtype) + + if src.dim() > 1: + shape = list(range(src.dim())) + shape[0], shape[1] = shape[1], shape[0] + src = src.transpose(shape).contiguous().transpose(shape) + if index.dim() > 1: + shape = list(range(index.dim())) + shape[0], shape[1] = shape[1], shape[0] + index = index.transpose(shape).contiguous().transpose(shape) + + out = getattr(paddle_scatter, "scatter_" + reduce)(src, index, dim) + if reduce == "min" or reduce == "max": + out, arg_out = out + arg_expected = tensor(test["arg_" + reduce], ind_dtype) + assert paddle.all(arg_out == arg_expected) + assert paddle.all(out == expected) + + @pytest.mark.parametrize( "test,reduce,dtype,ind_dtype,place", product(tests, reductions, dtypes, ind_dtypes, places), ) -def test_non_contiguous(test, reduce, dtype, ind_dtype, place): +def test_non_contiguous_half(test, reduce, dtype, ind_dtype, place): paddle.set_device(place) src = tensor(test["src"], dtype) index = tensor(test["index"], ind_dtype) diff --git a/jointContribution/paddle_scatter/tests/test_segment.py b/jointContribution/paddle_scatter/tests/test_segment.py index 4acb788aa..1c3866c45 100644 --- a/jointContribution/paddle_scatter/tests/test_segment.py +++ b/jointContribution/paddle_scatter/tests/test_segment.py @@ -5,6 +5,7 @@ import paddle_scatter import pytest from paddle_scatter.testing import dtypes +from paddle_scatter.testing import dtypes_half from paddle_scatter.testing import ind_dtypes from paddle_scatter.testing import places from paddle_scatter.testing import reductions @@ -200,6 +201,39 @@ def test_forward(test, reduce, dtype, ind_dtype, place): assert paddle.all(out == expected) +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize( + "test,reduce,dtype,ind_dtype", product(tests, reductions, dtypes_half, ind_dtypes) +) +def test_forward_half(test, reduce, dtype, ind_dtype): + paddle.set_device("gpu") + src = tensor(test["src"], 
dtype) + index = tensor(test["index"], ind_dtype) + indptr = tensor(test["indptr"], ind_dtype) + expected = tensor(test[reduce], dtype) + + fn = getattr(paddle_scatter, "segment_" + reduce + "_csr") + out = fn(src, indptr) + if reduce == "min" or reduce == "max": + out, arg_out = out + arg_expected = tensor(test["arg_" + reduce], ind_dtype) + assert paddle.all(arg_out == arg_expected) + assert paddle.all(out == expected) + + fn = getattr(paddle_scatter, "segment_" + reduce + "_coo") + out = fn(src, index) + if reduce == "min" or reduce == "max": + out, arg_out = out + arg_expected = tensor(test["arg_" + reduce], ind_dtype) + assert paddle.all(arg_out == arg_expected) + assert paddle.all(out == expected) + + @pytest.mark.parametrize("test,reduce,place", product(tests, reductions, places)) def test_backward(test, reduce, place): paddle.set_device(place) @@ -224,6 +258,36 @@ def test_backward(test, reduce, place): ) +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize("test,reduce", product(tests, reductions)) +def test_backward_half(test, reduce): + paddle.set_device("gpu") + index = tensor(test["index"], paddle.int64) + indptr = tensor(test["indptr"], paddle.int64) + exp_grad = tensor(test[f"{reduce}_grad"], paddle.float16) + + src = tensor(test["src"], paddle.float16) + src.stop_gradient = False + out = paddle_scatter.segment_csr(src, indptr, None, reduce) + out.backward() + np.testing.assert_allclose( + src.grad.numpy(), exp_grad.numpy(), rtol=1e-05, atol=1e-06 + ) + + src = tensor(test["src"], paddle.float16) + src.stop_gradient = False + out = paddle_scatter.segment_coo(src, index, None, None, reduce) + out.backward() + np.testing.assert_allclose( + src.grad.numpy(), exp_grad.numpy(), rtol=1e-05, atol=1e-06 + ) + + @pytest.mark.parametrize( "test,reduce,dtype,ind_dtype,place", product(tests, reductions, dtypes, ind_dtypes, places), @@ -258,6 +322,45 @@ def test_out(test, reduce, dtype, ind_dtype, place): assert paddle.all(out == expected) +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize( + "test,reduce,dtype,ind_dtype", product(tests, reductions, dtypes_half, ind_dtypes) +) +def test_out_half(test, reduce, dtype, ind_dtype): + paddle.set_device("gpu") + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + indptr = tensor(test["indptr"], ind_dtype) + expected = tensor(test[reduce], dtype) + + out = paddle.full_like(expected, -2) + + getattr(paddle_scatter, "segment_" + reduce + "_csr")(src, indptr, out) + np.testing.assert_allclose(out.numpy(), expected.numpy(), rtol=1e-05, atol=1e-06) + + out.fill_(-2) + + getattr(paddle_scatter, "segment_" + reduce + "_coo")(src, index, out) + + if reduce == "sum" or reduce == "add": + expected = expected - 2 + elif reduce == "mean": + expected = out # We can not really test this here. 
+ elif reduce == "min": + expected = expected.fill_(-2) + elif reduce == "max": + expected[expected == 0] = -2 + else: + raise ValueError + + assert paddle.all(out == expected) + + @pytest.mark.parametrize( "test,reduce,dtype,ind_dtype,place", product(tests, reductions, dtypes, ind_dtypes, places), @@ -295,3 +398,47 @@ def test_non_contiguous(test, reduce, dtype, ind_dtype, place): arg_expected = tensor(test["arg_" + reduce], ind_dtype) assert paddle.all(arg_out == arg_expected) assert paddle.all(out == expected) + + +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize( + "test,reduce,dtype,ind_dtype", product(tests, reductions, dtypes_half, ind_dtypes) +) +def test_non_contiguous_half(test, reduce, dtype, ind_dtype): + paddle.set_device("gpu") + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + indptr = tensor(test["indptr"], ind_dtype) + expected = tensor(test[reduce], dtype) + + if src.dim() > 1: + shape = list(range(src.dim())) + shape[0], shape[1] = shape[1], shape[0] + src = src.transpose(shape).contiguous().transpose(shape) + if index.dim() > 1: + shape = list(range(index.dim())) + shape[0], shape[1] = shape[1], shape[0] + index = index.transpose(shape).contiguous().transpose(shape) + if indptr.dim() > 1: + shape = list(range(index.dim())) + shape[0], shape[1] = shape[1], shape[0] + indptr = indptr.transpose(shape).contiguous().transpose(shape) + + out = getattr(paddle_scatter, "segment_" + reduce + "_csr")(src, indptr) + if reduce == "min" or reduce == "max": + out, arg_out = out + arg_expected = tensor(test["arg_" + reduce], ind_dtype) + assert paddle.all(arg_out == arg_expected) + assert paddle.all(out == expected) + + out = getattr(paddle_scatter, "segment_" + reduce + "_coo")(src, index) + if reduce == "min" or reduce == "max": + out, arg_out = out + arg_expected = tensor(test["arg_" + reduce], ind_dtype) + assert paddle.all(arg_out == arg_expected) + assert paddle.all(out == expected) diff --git a/jointContribution/paddle_scatter/utils.py b/jointContribution/paddle_scatter/utils.py index b2db4874f..7356a011b 100644 --- a/jointContribution/paddle_scatter/utils.py +++ b/jointContribution/paddle_scatter/utils.py @@ -34,40 +34,3 @@ def broadcast(src: paddle.Tensor, other: paddle.Tensor, dim: int) -> paddle.Tens return src.reshape(other.shape) src = src.expand(other.shape) return src - - -def transform_3d(tensor: paddle.Tensor, dim: int) -> paddle.Tensor: - r"""Transform a tensor to [pre, dim, post] 3d dimensions. - - Args: - tensor (paddle.Tensor): The input tensor. - dim (int): The target dimension. - - Returns: - paddle.Tensor, the transformed tensor. - """ - if tensor.dim() == 1: - return tensor.unsqueeze(0).unsqueeze(-1) - elif tensor.dim() == 2: - if dim == 0: - return tensor.unsqueeze(0) - elif dim == 1: - return tensor.unsqueeze(-1) - else: - return tensor.flatten(0, dim - 1).flatten(2, -1) - - -def transform_2d(tensor: paddle.Tensor, dim: int) -> paddle.Tensor: - r"""Transform a tensor to [pre, dim] 2d dimensions. - - Args: - tensor (paddle.Tensor): The input tensor. - dim (int): The target dimension. - - Returns: - paddle.Tensor, the transformed tensor. - """ - if tensor.dim() == 1: - return tensor.unsqueeze(0) - else: - return tensor.flatten(0, dim - 1)
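
The sketches below illustrate the public API this patch rewires onto the renamed paddle_scatter_ops extension; they are usage sketches, not part of the patch, and assume the paddle_scatter package and the extension built by setup.py are importable as they are in the tests. First, the scatter entry points: sum/mean stay on the pure-Paddle put_along_axis/counting path in scatter.py, while min/max call the custom op and also return the source positions along dim.

import paddle
import paddle_scatter

src = paddle.to_tensor([2.0, 0.0, 1.0, 4.0, 3.0])
index = paddle.to_tensor([0, 0, 1, 1, 1])

# sum/mean use paddle.put_along_axis(reduce="add") and a clipped count
print(paddle_scatter.scatter_sum(src, index, dim=0))   # [2., 8.]
print(paddle_scatter.scatter_mean(src, index, dim=0))  # [1., 2.6667] (approx.)

# min/max route through paddle_scatter_ops.custom_scatter_min_max and also
# return the position along `dim` that produced each reduced value
vals, argmax = paddle_scatter.scatter_max(src, index, dim=0)
# vals -> [2., 4.], argmax -> [0, 3]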
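The CSR segment reductions take an index pointer whose last dimension delimits contiguous segments of src; a minimal sketch of the semantics segment_sum_csr and segment_min_csr expose, with values following the upstream pytorch_scatter convention.

import paddle
from paddle_scatter import segment_min_csr, segment_sum_csr

src = paddle.to_tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
indptr = paddle.to_tensor([0, 2, 5, 6])   # segments [0, 2), [2, 5), [5, 6)

print(segment_sum_csr(src, indptr))       # [3., 12., 6.]

# min/max additionally return the positions of the reduced values within src
vals, argmin = segment_min_csr(src, indptr)
# vals -> [1., 3., 6.], argmin -> [0, 2, 5]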
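The COO variant expects index to be sorted along its last dimension; when dim_size is not given, segment_sum_coo infers the output extent from the last index entry plus one, exactly as the size computation in segment_coo.py does. A sketch, with arg positions following the upstream convention.

import paddle
from paddle_scatter import segment_max_coo, segment_sum_coo

src = paddle.to_tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
index = paddle.to_tensor([0, 0, 1, 1, 1, 2])   # sorted along the last dim

print(segment_sum_coo(src, index))             # inferred dim_size = 2 + 1 -> [3., 12., 6.]

vals, argmax = segment_max_coo(src, index)
# vals -> [2., 5., 6.], argmax -> [1, 4, 5] (positions within src)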
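gather_csr and gather_coo are the expansion counterparts: one value per segment is written back to every element position of that segment, which is also why gather_csr sizes its output from the last indptr entry. A sketch under the same assumptions as above.

import paddle
from paddle_scatter import gather_coo, gather_csr

src = paddle.to_tensor([10.0, 20.0, 30.0])
indptr = paddle.to_tensor([0, 2, 5, 6])
index = paddle.to_tensor([0, 0, 1, 1, 1, 2])

print(gather_csr(src, indptr))   # [10., 10., 20., 20., 20., 30.]
print(gather_coo(src, index))    # [10., 10., 20., 20., 20., 30.]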
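The CUDA host wrappers choose between the plain and the *_broadcast_* kernels from an N/K/E decomposition of the shapes: N rows to reduce, K inner elements carried along per row, E source elements along the reduced dimension. The arithmetic below mirrors the formulas in segment_csr_cuda_forward for a hypothetical shape; the shapes themselves are illustrative only.

# src: [4, 20, 8], indptr: [4, 6]  ->  5 segments per row, reduced dim = 1
x_shape = [4, 20, 8]
indptr_shape = [4, 6]
return_shape = [4, 5, 8]

dim = len(indptr_shape) - 1
indptr_numel = 4 * 6
out_numel = 4 * 5 * 8

N = return_shape[dim] * (indptr_numel // indptr_shape[dim])  # 5 * 4 = 20 rows
K = out_numel // N                                           # 8  -> K > 1, broadcast kernel
E = x_shape[dim]                                             # 20 source elements per row span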
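The new float16/bfloat16 tests are gated on CUDA being compiled in plus device support for the half dtypes. A sketch of that gating, factored into a reusable pytest marker; it reuses the availability checks the tests already call, the helper name is illustrative, and whether those checks need an explicit place argument depends on the Paddle version.

import paddle
import pytest


def half_supported() -> bool:
    # run half-precision tests only when CUDA is compiled in and the
    # current device reports float16/bfloat16 support
    return (
        paddle.core.is_compiled_with_cuda()
        and paddle.core.is_float16_supported()
        and paddle.core.is_bfloat16_supported()
    )


requires_half = pytest.mark.skipif(
    not half_supported(), reason="half dtype not available"
)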