From 8c5351c3b0be0a85589ae9805dc1e08477f8348a Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Sun, 24 Nov 2024 12:58:25 +0800 Subject: [PATCH 1/2] add paddle-scatter --- jointContribution/paddle_scatter/__init__.py | 53 ++ .../paddle_scatter/composite/__init__.py | 11 + .../paddle_scatter/composite/logsumexp.py | 82 ++++ .../paddle_scatter/composite/softmax.py | 99 ++++ .../paddle_scatter/composite/std.py | 67 +++ .../paddle_scatter/csrc/atomics.cuh | 152 ++++++ .../paddle_scatter/csrc/index_info.cuh | 83 ++++ .../paddle_scatter/csrc/index_info.h | 67 +++ .../paddle_scatter/csrc/scatter_min_max.cc | 228 +++++++++ .../paddle_scatter/csrc/scatter_min_max.cu | 206 ++++++++ .../csrc/segment_coo_min_max.cc | 267 ++++++++++ .../csrc/segment_coo_min_max.cu | 375 ++++++++++++++ .../csrc/segment_csr_min_max.cc | 196 ++++++++ .../csrc/segment_csr_min_max.cu | 207 ++++++++ .../paddle_scatter/csrc/utils.cuh | 88 ++++ jointContribution/paddle_scatter/csrc/utils.h | 26 + jointContribution/paddle_scatter/scatter.py | 398 +++++++++++++++ .../paddle_scatter/segment_coo.py | 461 ++++++++++++++++++ .../paddle_scatter/segment_csr.py | 454 +++++++++++++++++ jointContribution/paddle_scatter/setup.py | 41 ++ jointContribution/paddle_scatter/testing.py | 23 + .../tests/composite/test_logsumexp.py | 39 ++ .../tests/composite/test_softmax.py | 51 ++ .../tests/composite/test_std.py | 27 + .../paddle_scatter/tests/test_broadcasting.py | 29 ++ .../paddle_scatter/tests/test_gather.py | 145 ++++++ .../paddle_scatter/tests/test_multi_gpu.py | 43 ++ .../paddle_scatter/tests/test_scatter.py | 296 +++++++++++ .../paddle_scatter/tests/test_segment.py | 297 +++++++++++ .../paddle_scatter/tests/test_zero_tensors.py | 45 ++ jointContribution/paddle_scatter/utils.py | 73 +++ 31 files changed, 4629 insertions(+) create mode 100644 jointContribution/paddle_scatter/__init__.py create mode 100644 jointContribution/paddle_scatter/composite/__init__.py create mode 100644 jointContribution/paddle_scatter/composite/logsumexp.py create mode 100644 jointContribution/paddle_scatter/composite/softmax.py create mode 100644 jointContribution/paddle_scatter/composite/std.py create mode 100644 jointContribution/paddle_scatter/csrc/atomics.cuh create mode 100644 jointContribution/paddle_scatter/csrc/index_info.cuh create mode 100644 jointContribution/paddle_scatter/csrc/index_info.h create mode 100644 jointContribution/paddle_scatter/csrc/scatter_min_max.cc create mode 100644 jointContribution/paddle_scatter/csrc/scatter_min_max.cu create mode 100644 jointContribution/paddle_scatter/csrc/segment_coo_min_max.cc create mode 100644 jointContribution/paddle_scatter/csrc/segment_coo_min_max.cu create mode 100644 jointContribution/paddle_scatter/csrc/segment_csr_min_max.cc create mode 100644 jointContribution/paddle_scatter/csrc/segment_csr_min_max.cu create mode 100644 jointContribution/paddle_scatter/csrc/utils.cuh create mode 100644 jointContribution/paddle_scatter/csrc/utils.h create mode 100644 jointContribution/paddle_scatter/scatter.py create mode 100644 jointContribution/paddle_scatter/segment_coo.py create mode 100644 jointContribution/paddle_scatter/segment_csr.py create mode 100644 jointContribution/paddle_scatter/setup.py create mode 100644 jointContribution/paddle_scatter/testing.py create mode 100644 jointContribution/paddle_scatter/tests/composite/test_logsumexp.py create mode 100644 jointContribution/paddle_scatter/tests/composite/test_softmax.py create mode 100644 
jointContribution/paddle_scatter/tests/composite/test_std.py create mode 100644 jointContribution/paddle_scatter/tests/test_broadcasting.py create mode 100644 jointContribution/paddle_scatter/tests/test_gather.py create mode 100644 jointContribution/paddle_scatter/tests/test_multi_gpu.py create mode 100644 jointContribution/paddle_scatter/tests/test_scatter.py create mode 100644 jointContribution/paddle_scatter/tests/test_segment.py create mode 100644 jointContribution/paddle_scatter/tests/test_zero_tensors.py create mode 100644 jointContribution/paddle_scatter/utils.py diff --git a/jointContribution/paddle_scatter/__init__.py b/jointContribution/paddle_scatter/__init__.py new file mode 100644 index 000000000..3ff069e74 --- /dev/null +++ b/jointContribution/paddle_scatter/__init__.py @@ -0,0 +1,53 @@ +from .composite import scatter_log_softmax +from .composite import scatter_logsumexp +from .composite import scatter_softmax +from .composite import scatter_std +from .scatter import scatter +from .scatter import scatter_add +from .scatter import scatter_max +from .scatter import scatter_mean +from .scatter import scatter_min +from .scatter import scatter_mul +from .scatter import scatter_sum +from .segment_coo import gather_coo +from .segment_coo import segment_add_coo +from .segment_coo import segment_coo +from .segment_coo import segment_max_coo +from .segment_coo import segment_mean_coo +from .segment_coo import segment_min_coo +from .segment_coo import segment_sum_coo +from .segment_csr import gather_csr +from .segment_csr import segment_add_csr +from .segment_csr import segment_csr +from .segment_csr import segment_max_csr +from .segment_csr import segment_mean_csr +from .segment_csr import segment_min_csr +from .segment_csr import segment_sum_csr + +__all__ = [ + "scatter_sum", + "scatter_add", + "scatter_mul", + "scatter_mean", + "scatter_min", + "scatter_max", + "scatter", + "segment_sum_csr", + "segment_add_csr", + "segment_mean_csr", + "segment_min_csr", + "segment_max_csr", + "segment_csr", + "gather_csr", + "segment_sum_coo", + "segment_add_coo", + "segment_mean_coo", + "segment_min_coo", + "segment_max_coo", + "segment_coo", + "gather_coo", + "scatter_std", + "scatter_logsumexp", + "scatter_softmax", + "scatter_log_softmax", +] diff --git a/jointContribution/paddle_scatter/composite/__init__.py b/jointContribution/paddle_scatter/composite/__init__.py new file mode 100644 index 000000000..e446f7354 --- /dev/null +++ b/jointContribution/paddle_scatter/composite/__init__.py @@ -0,0 +1,11 @@ +from .logsumexp import scatter_logsumexp +from .softmax import scatter_log_softmax +from .softmax import scatter_softmax +from .std import scatter_std + +__all__ = [ + "scatter_std", + "scatter_logsumexp", + "scatter_softmax", + "scatter_log_softmax", +] diff --git a/jointContribution/paddle_scatter/composite/logsumexp.py b/jointContribution/paddle_scatter/composite/logsumexp.py new file mode 100644 index 000000000..3d66d040a --- /dev/null +++ b/jointContribution/paddle_scatter/composite/logsumexp.py @@ -0,0 +1,82 @@ +from typing import Optional + +import paddle +from paddle_scatter import scatter_max +from paddle_scatter import scatter_sum +from paddle_scatter.utils import broadcast + + +def scatter_logsumexp( + src: paddle.Tensor, + index: paddle.Tensor, + dim: int = -1, + out: Optional[paddle.Tensor] = None, + dim_size: Optional[int] = None, + eps: float = 1e-12, +) -> paddle.Tensor: + r"""Reduces all values from the `src` tensor into `out` at the + indices specified in the `index` tensor 
along a given axis`dim`, + the reduction method is logsumexp. (If dtype of `src` is int, output is still int.) + + Args: + src (paddle.Tensor): The source tensor. + index (paddle.Tensor): The indices of elements to scatter. The dimension + of index should either be 1-D or :math:`i+1`-D. See Notes for more + details. + dim (int, optional): The axis along which to index. Default is -1. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + dim_size (int|None, optional): If `out` is not given, automatically create output + with size `dim_size` at dimension `dim`. If `dim_size` is not given, + a minimal sized output tensor according to `index.max() + 1` is returned. + Default is None. + eps (float, optional): Eplison factor added to the sum of exponent values during + computation in case they are zero. Default is 1e-12. + + Returns: + paddle.Tensor, the reduced tensor by logsumexp reduction method. + """ + if not paddle.is_floating_point(src): + raise ValueError( + "`scatter_logsumexp` can only be computed over " + "tensors with floating point data types." + ) + + index = broadcast(index, src, dim) + eps = paddle.to_tensor(eps, dtype=src.dtype) + + if out is not None: + dim_size = out.shape[dim] + else: + if dim_size is None: + dim_size = int(index.max()) + 1 + + size = src.shape + size[dim] = dim_size + max_value_per_index = paddle.full( + size, + fill_value=float("-inf"), + dtype=src.dtype, + ) + scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size)[0] + max_per_src_element = max_value_per_index.take_along_axis(indices=index, axis=dim) + recentered_score = src - max_per_src_element + recentered_score.masked_fill_(paddle.isnan(recentered_score), float("-inf")) + + orig_out: Optional[paddle.Tensor] = None + if out is not None: + orig_out = out.clone() + res = out.subtract(max_value_per_index).exp() + + sum_per_index = scatter_sum(recentered_score.exp(), index, dim, res, dim_size) + else: + sum_per_index = scatter_sum(recentered_score.exp(), index, dim, None, dim_size) + + res = sum_per_index.add(eps).log().add(max_value_per_index) + + if orig_out is None: + return res.nan_to_num_(neginf=0.0) + + mask = ~res.isfinite() + res[mask] = orig_out[mask] + paddle.assign(res, out) + return out diff --git a/jointContribution/paddle_scatter/composite/softmax.py b/jointContribution/paddle_scatter/composite/softmax.py new file mode 100644 index 000000000..b31b54d44 --- /dev/null +++ b/jointContribution/paddle_scatter/composite/softmax.py @@ -0,0 +1,99 @@ +from typing import Optional + +import paddle +from paddle_scatter import scatter_max +from paddle_scatter import scatter_sum +from paddle_scatter.utils import broadcast + + +def scatter_softmax( + src: paddle.Tensor, + index: paddle.Tensor, + dim: int = -1, + dim_size: Optional[int] = None, +) -> paddle.Tensor: + r"""Reduces all values from the `src` tensor into `out` at the + indices specified in the `index` tensor along a given axis`dim`, + the reduction method is softmax. (If dtype of `src` is int, output is still int.) + + Args: + src (paddle.Tensor): The source tensor. + index (paddle.Tensor): The indices of elements to scatter. The dimension + of index should either be 1-D or :math:`i+1`-D. See Notes for more + details. + dim (int, optional): The axis along which to index. Default is -1. + dim_size (int|None, optional): If `out` is not given, automatically create output + with size `dim_size` at dimension `dim`. 
If `dim_size` is not given, + a minimal sized output tensor according to `index.max() + 1` is returned. + Default is None. + + Returns: + paddle.Tensor, the reduced tensor by softmax reduction method. + """ + if not paddle.is_floating_point(src): + raise ValueError( + "`scatter_softmax` can only be computed over tensors " + "with floating point data types." + ) + + index = broadcast(index, src, dim) + + max_value_per_index = scatter_max(src, index, dim=dim, dim_size=dim_size)[0] + max_per_src_element = max_value_per_index.take_along_axis(indices=index, axis=dim) + + recentered_scores = src - max_per_src_element + recentered_scores_exp = recentered_scores.exp() + + sum_per_index = scatter_sum(recentered_scores_exp, index, dim, dim_size=dim_size) + normalizing_constants = sum_per_index.take_along_axis(indices=index, axis=dim) + + return recentered_scores_exp.divide(normalizing_constants) + + +def scatter_log_softmax( + src: paddle.Tensor, + index: paddle.Tensor, + dim: int = -1, + eps: float = 1e-12, + dim_size: Optional[int] = None, +) -> paddle.Tensor: + r"""Reduces all values from the `src` tensor into `out` at the + indices specified in the `index` tensor along a given axis`dim`, + the reduction method is log_softmax. (If dtype of `src` is int, output is still int.) + + Args: + src (paddle.Tensor): The source tensor. + index (paddle.Tensor): The indices of elements to scatter. The dimension + of index should either be 1-D or :math:`i+1`-D. See Notes for more + details. + dim (int, optional): The axis along which to index. Default is -1. + eps (float, optional): Eplison factor added to the normalizing constants during + computation in case they are zero. Default is 1e-12. + dim_size (int|None, optional): If `out` is not given, automatically create output + with size `dim_size` at dimension `dim`. If `dim_size` is not given, + a minimal sized output tensor according to `index.max() + 1` is returned. + Default is None. + + Returns: + paddle.Tensor, the reduced tensor by log_softmax reduction method. + """ + if not paddle.is_floating_point(src): + raise ValueError( + "`scatter_log_softmax` can only be computed over " + "tensors with floating point data types." + ) + + index = broadcast(index, src, dim) + eps = paddle.to_tensor(eps, dtype=src.dtype) + + max_value_per_index = scatter_max(src, index, dim=dim, dim_size=dim_size)[0] + max_per_src_element = max_value_per_index.take_along_axis(indices=index, axis=dim) + + recentered_scores = src - max_per_src_element + + sum_per_index = scatter_sum(recentered_scores.exp(), index, dim, dim_size=dim_size) + normalizing_constants = ( + sum_per_index.add(eps).log().take_along_axis(indices=index, axis=dim) + ) + + return recentered_scores.subtract(normalizing_constants) diff --git a/jointContribution/paddle_scatter/composite/std.py b/jointContribution/paddle_scatter/composite/std.py new file mode 100644 index 000000000..359e0b50d --- /dev/null +++ b/jointContribution/paddle_scatter/composite/std.py @@ -0,0 +1,67 @@ +from typing import Optional + +import paddle +from paddle_scatter import scatter_sum +from paddle_scatter.utils import broadcast + + +def scatter_std( + src: paddle.Tensor, + index: paddle.Tensor, + dim: int = -1, + out: Optional[paddle.Tensor] = None, + dim_size: Optional[int] = None, + unbiased: bool = True, +) -> paddle.Tensor: + r"""Reduces all values from the `src` tensor into `out` at the + indices specified in the `index` tensor along a given axis`dim`, + the reduction method is std. (If dtype of `src` is int, output is still int.) 
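A minimal usage sketch of the composite reductions defined in this subpackage (scatter_softmax, scatter_log_softmax, scatter_logsumexp, scatter_std); illustrative only, assuming the extension has been built and the package is importable as paddle_scatter:

import paddle
from paddle_scatter import scatter_logsumexp, scatter_softmax, scatter_std

# Three values grouped into two buckets by `index` (positions 0 and 1 form bucket 0).
src = paddle.to_tensor([1.0, 2.0, 3.0])
index = paddle.to_tensor([0, 0, 1])

print(scatter_softmax(src, index))    # softmax computed within each bucket, same shape as src
print(scatter_logsumexp(src, index))  # one logsumexp value per bucket -> shape [2]
print(scatter_std(src, index))        # per-bucket standard deviation -> shape [2]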
+ + Args: + src (paddle.Tensor): The source tensor. + index (paddle.Tensor): The indices of elements to scatter. The dimension + of index should either be 1-D or :math:`i+1`-D. See Notes for more + details. + dim (int, optional): The axis along which to index. Default is -1. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + dim_size (int|None, optional): If `out` is not given, automatically create output + with size `dim_size` at dimension `dim`. If `dim_size` is not given, + a minimal sized output tensor according to `index.max() + 1` is returned. + Default is None. + unbiased (bool, optional): Indicate whether to calculate biased std (divide by n) + or unbiased std (divide by n-1). Default is True. + + Returns: + paddle.Tensor, the reduced tensor by std reduction method. + """ + if out is not None: + dim_size = out.shape[dim] + + if dim < 0: + dim = src.dim() + dim + + count_dim = dim + if index.dim() <= dim: + count_dim = index.dim() - 1 + + ones = paddle.ones(index.shape, dtype=src.dtype) + count = scatter_sum(ones, index, count_dim, dim_size=dim_size) + + index = broadcast(index, src, dim) + tmp = scatter_sum(src, index, dim, dim_size=dim_size) + count = broadcast(count, tmp, dim).clip(1) + mean = tmp.divide(count) + + var = src - mean.take_along_axis(indices=index, axis=dim) + var = var * var + res = scatter_sum(var, index, dim, out, dim_size) + + if unbiased: + count = count.subtract(paddle.to_tensor(1, dtype=src.dtype)).clip(1) + res = res.divide(count + 1e-6).sqrt() + + if out is not None: + paddle.assign(res, out) + return out + else: + return res diff --git a/jointContribution/paddle_scatter/csrc/atomics.cuh b/jointContribution/paddle_scatter/csrc/atomics.cuh new file mode 100644 index 000000000..224d29889 --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/atomics.cuh @@ -0,0 +1,152 @@ +#pragma once + +#include "paddle/extension.h" + +#define ATOMIC(NAME) \ + template struct Atomic##NAME##IntegerImpl; \ + \ + template struct Atomic##NAME##IntegerImpl { \ + inline __device__ void operator()(scalar *address, scalar val) { \ + uint32_t *address_as_ui = (uint32_t *)(address - ((size_t)address & 3)); \ + uint32_t old = *address_as_ui; \ + uint32_t shift = ((size_t)address & 3) * 8; \ + uint32_t sum; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + sum = OP(val, scalar((old >> shift) & 0xff)); \ + old = (old & ~(0x000000ff << shift)) | (sum << shift); \ + old = atomicCAS(address_as_ui, assumed, old); \ + } while (assumed != old); \ + } \ + }; \ + \ + template struct Atomic##NAME##IntegerImpl { \ + inline __device__ void operator()(scalar *address, scalar val) { \ + uint32_t *address_as_ui = \ + (uint32_t *)((char *)address - ((size_t)address & 2)); \ + uint32_t old = *address_as_ui; \ + uint32_t sum; \ + uint32_t newval; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + sum = OP(val, (size_t)address & 2 ? scalar(old >> 16) \ + : scalar(old & 0xffff)); \ + newval = (size_t)address & 2 ? 
(old & 0xffff) | (sum << 16) \ + : (old & 0xffff0000) | sum; \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ + } \ + }; \ + \ + template struct Atomic##NAME##IntegerImpl { \ + inline __device__ void operator()(scalar *address, scalar val) { \ + uint32_t *address_as_ui = (uint32_t *)address; \ + uint32_t old = *address_as_ui; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + old = atomicCAS(address_as_ui, assumed, OP(val, (scalar)old)); \ + } while (assumed != old); \ + } \ + }; \ + \ + template struct Atomic##NAME##IntegerImpl { \ + inline __device__ void operator()(scalar *address, scalar val) { \ + unsigned long long *address_as_ull = (unsigned long long *)address; \ + unsigned long long old = *address_as_ull; \ + unsigned long long assumed; \ + \ + do { \ + assumed = old; \ + old = atomicCAS(address_as_ull, assumed, OP(val, (scalar)old)); \ + } while (assumed != old); \ + } \ + }; \ + \ + template struct Atomic##NAME##DecimalImpl; \ + \ + \ + template struct Atomic##NAME##DecimalImpl { \ + inline __device__ void operator()(scalar *address, scalar val) { \ + int *address_as_i = (int *)address; \ + int old = *address_as_i; \ + int assumed; \ + \ + do { \ + assumed = old; \ + old = atomicCAS(address_as_i, assumed, \ + __float_as_int(OP(val, __int_as_float(assumed)))); \ + } while (assumed != old); \ + } \ + }; \ + \ + template struct Atomic##NAME##DecimalImpl { \ + inline __device__ void operator()(scalar *address, scalar val) { \ + unsigned long long int *address_as_ull = \ + (unsigned long long int *)address; \ + unsigned long long int old = *address_as_ull; \ + unsigned long long int assumed; \ + \ + do { \ + assumed = old; \ + old = atomicCAS( \ + address_as_ull, assumed, \ + __double_as_longlong(OP(val, __longlong_as_double(assumed)))); \ + } while (assumed != old); \ + } \ + }; + +#define OP(X, Y) max(Y, X) +ATOMIC(Max) +#undef OP +static inline __device__ void atomMax(uint8_t *address, uint8_t val) { + AtomicMaxIntegerImpl()(address, val); +} +static inline __device__ void atomMax(int8_t *address, int8_t val) { + AtomicMaxIntegerImpl()(address, val); +} +static inline __device__ void atomMax(int16_t *address, int16_t val) { + AtomicMaxIntegerImpl()(address, val); +} +static inline __device__ void atomMax(int32_t *address, int32_t val) { + atomicMax(address, val); +} +static inline __device__ void atomMax(int64_t *address, int64_t val) { + AtomicMaxIntegerImpl()(address, val); +} +static inline __device__ void atomMax(float *address, float val) { + AtomicMaxDecimalImpl()(address, val); +} +static inline __device__ void atomMax(double *address, double val) { + AtomicMaxDecimalImpl()(address, val); +} + +#define OP(X, Y) min(Y, X) +ATOMIC(Min) +#undef OP +static inline __device__ void atomMin(uint8_t *address, uint8_t val) { + AtomicMinIntegerImpl()(address, val); +} +static inline __device__ void atomMin(int8_t *address, int8_t val) { + AtomicMinIntegerImpl()(address, val); +} +static inline __device__ void atomMin(int16_t *address, int16_t val) { + AtomicMinIntegerImpl()(address, val); +} +static inline __device__ void atomMin(int32_t *address, int32_t val) { + atomicMin(address, val); +} +static inline __device__ void atomMin(int64_t *address, int64_t val) { + AtomicMinIntegerImpl()(address, val); +} +static inline __device__ void atomMin(float *address, float val) { + AtomicMinDecimalImpl()(address, val); +} +static inline __device__ void atomMin(double *address, double val) { + AtomicMinDecimalImpl()(address, val); +} diff --git 
a/jointContribution/paddle_scatter/csrc/index_info.cuh b/jointContribution/paddle_scatter/csrc/index_info.cuh new file mode 100644 index 000000000..64ecd28de --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/index_info.cuh @@ -0,0 +1,83 @@ +#pragma once + +#include "paddle/extension.h" + +#define MAX_TENSORINFO_DIMS 25 + +template struct TensorInfo { + TensorInfo(const T *p, int dim, int sz[MAX_TENSORINFO_DIMS], + int st[MAX_TENSORINFO_DIMS]) { + data = p; + dims = dim; + PD_CHECK(dims < MAX_TENSORINFO_DIMS, "Input dims should be smaller than 25."); + + for (int i = 0; i < dim; ++i) { + sizes[i] = sz[i]; + strides[i] = st[i]; + } + } + + const T* data; + int dims; + int sizes[MAX_TENSORINFO_DIMS]; + int strides[MAX_TENSORINFO_DIMS]; +}; + + +template +TensorInfo getTensorInfo(const paddle::Tensor &tensor) { + int sizes[MAX_TENSORINFO_DIMS]; + int strides[MAX_TENSORINFO_DIMS]; + + int dims = tensor.shape().size(); + // int stride_i = 1; + for (int i = dims - 1; i >= 0; --i) { + sizes[i] = tensor.shape()[i]; + // strides[i] = stride_i; + // stride_i *= sizes[i]; + sizes[i] = tensor.strides()[i]; + } + + return TensorInfo(tensor.data(), dims, sizes, + strides); +} + +// Uses dynamic (runtime) instead of static (compiletime) dims +template +struct IndexToOffset { + static inline __host__ __device__ int get( + int linearId, + const TensorInfo& info) { + + int offset = 0; + + for (int i = info.dims - 1; i > 0; --i) { + int curDimIndex = linearId % info.sizes[i]; + int curDimOffset = curDimIndex * info.strides[i]; + offset += curDimOffset; + linearId /= info.sizes[i]; + } + + return offset + linearId * info.strides[0]; + } +}; + +// We need our own `IndexToOffset` implementation since we do not want to +// access the last element of the `indexptr`. 
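For orientation, the offset arithmetic implemented by IndexToOffset and IndexPtrToOffset can be modelled in plain Python as below; this is an illustrative sketch of the device-side logic, not part of the patch:

def index_to_offset(linear_id, sizes, strides):
    # Peel off one coordinate per dimension, innermost first, and
    # accumulate its strided contribution to the memory offset.
    offset = 0
    for size, stride in zip(reversed(sizes[1:]), reversed(strides[1:])):
        offset += (linear_id % size) * stride
        linear_id //= size
    return offset + linear_id * strides[0]

def indexptr_to_offset(idx, sizes, strides):
    # Same idea, but the innermost dimension stores `indptr` boundaries,
    # so only sizes[-1] - 1 entries are addressable per row and the last
    # element of `indptr` is never touched.
    offset = (idx % (sizes[-1] - 1)) * strides[-1]
    idx //= sizes[-1] - 1
    for size, stride in zip(reversed(sizes[:-1]), reversed(strides[:-1])):
        offset += (idx % size) * stride
        idx //= size
    return offset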
+template +struct IndexPtrToOffset { + static inline __host__ __device__ int get( + int idx, + const TensorInfo& info) { + + int offset = idx % (info.sizes[info.dims - 1] - 1); + offset *= info.strides[info.dims - 1]; + idx /= info.sizes[info.dims - 1] - 1; + for (int i = info.dims - 2; i >= 0; --i) { + offset += (idx % info.sizes[i]) * info.strides[i]; + idx /= info.sizes[i]; + } + + return offset; + } +}; diff --git a/jointContribution/paddle_scatter/csrc/index_info.h b/jointContribution/paddle_scatter/csrc/index_info.h new file mode 100644 index 000000000..c2751898b --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/index_info.h @@ -0,0 +1,67 @@ +#pragma once + +#include "paddle/extension.h" + +#define MAX_TENSORINFO_DIMS 25 + +template struct TensorInfo { + TensorInfo(const T *p, int dim, int sz[MAX_TENSORINFO_DIMS], + int st[MAX_TENSORINFO_DIMS]) { + data = p; + dims = dim; + PD_CHECK(dims < MAX_TENSORINFO_DIMS, "Input dims should be smaller than 25."); + + for (int i = 0; i < dim; ++i) { + sizes[i] = sz[i]; + strides[i] = st[i]; + } + } + + const T* data; + int dims; + int sizes[MAX_TENSORINFO_DIMS]; + int strides[MAX_TENSORINFO_DIMS]; +}; + +template +TensorInfo getTensorInfo(const paddle::Tensor &tensor) { + int sizes[MAX_TENSORINFO_DIMS]; + int strides[MAX_TENSORINFO_DIMS]; + + int dims = tensor.shape().size(); + // int stride_i = 1; + for (int i = dims - 1; i >= 0; --i) { + sizes[i] = tensor.shape()[i]; + // strides[i] = stride_i; + // stride_i *= sizes[i]; + strides[i] = tensor.strides()[i]; + } + + + return TensorInfo(tensor.data(), dims, sizes, + strides); +} + +template struct IndexToOffset { + static inline int get(int idx, const TensorInfo &info) { + int offset = 0; + for (int i = info.dims - 1; i >= 0; --i) { + offset += (idx % info.sizes[i]) * info.strides[i]; + idx /= info.sizes[i]; + } + return offset; + } +}; + +template struct IndexPtrToOffset { + static inline int get(int idx, const TensorInfo &info) { + int offset = idx % (info.sizes[info.dims - 1] - 1); + offset *= info.strides[info.dims - 1]; + idx /= info.sizes[info.dims - 1] - 1; + for (int i = info.dims - 2; i >= 0; --i) { + offset += (idx % info.sizes[i]) * info.strides[i]; + idx /= info.sizes[i]; + } + return offset; + } +}; diff --git a/jointContribution/paddle_scatter/csrc/scatter_min_max.cc b/jointContribution/paddle_scatter/csrc/scatter_min_max.cc new file mode 100644 index 000000000..1828d57ae --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/scatter_min_max.cc @@ -0,0 +1,228 @@ +#include "paddle/extension.h" + +#include +#include +#include + +#include "index_info.h" +#include "utils.h" + + +template +void scatter_min_max_cpu_forward_kernel(const data_t* x_data, + const index_t* index_data, + const std::vector& return_shape, + const std::vector& x_dims, + const std::string& reduce, + const TensorInfo& index_info, + int64_t x_numel, + int64_t dim, + bool post_process, + data_t* out_data, + index_t* arg_out_data) { + using MPType = typename MPTypeTrait::Type; + auto B = 1; + for (auto i = 0; i < dim; ++i) + B *= x_dims[i]; + auto E = x_dims[dim]; + auto K = x_numel / (B * E); + auto N = return_shape[dim]; + + int64_t i, idx, out_data_idx; + for (auto b = 0; b < B; b++) { + for (auto e = 0; e < E; e++) { + for (auto k = 0; k < K; k++) { + i = b * E * K + e * K + k; + idx = index_info.data[IndexToOffset::get(i, index_info)]; + out_data_idx = b * N * K + idx * K + k; + if ((reduce == "min" && x_data[i] < out_data[out_data_idx]) || + (reduce == "max" && x_data[i] > out_data[out_data_idx])) { + 
out_data[out_data_idx] = x_data[i]; + arg_out_data[out_data_idx] = e; + } + } + } + } + + if (post_process) { + auto out_numel = std::accumulate(return_shape.begin(), return_shape.end(), 1.0, std::multiplies()); + data_t init_val = static_cast((reduce == "min") ? std::numeric_limits::max() : std::numeric_limits::lowest()); + for (auto i = 0; i < out_numel; ++i) { + if (out_data[i] == init_val) + out_data[i] = static_cast(0.0); + } + } +} + +std::vector scatter_min_max_cpu_forward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + const std::vector& return_shape, + const std::string& reduce, + int64_t dim) { + CHECK_CPU(index); + if (init) + CHECK_CPU(init.get()); + + auto x_dims = x.shape(); + auto index_dims = index.shape(); + CHECK_INPUT(x_dims.size() == index_dims.size()); + for (auto i = 0; i < index_dims.size() - 1; ++i) + CHECK_INPUT(x_dims[i] >= index_dims[i]); + + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + paddle::Tensor arg_out; + arg_out = paddle::experimental::full_like(out, x_dims[dim], index.dtype(), index.place()); + + PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES( + x.dtype(), "scatter_min_max_cpu_forward_kernel", ([&] { + + using MPType = typename MPTypeTrait::Type; + if (!init) { + if (reduce == "min") + paddle::experimental::fill_(out, static_cast(std::numeric_limits::max())); + else + paddle::experimental::fill_(out, static_cast(std::numeric_limits::lowest())); + } + + bool post_process = (!init) ? true : false; + switch(index.dtype()) { + case paddle::DataType::INT32: + { + auto index_info = getTensorInfo(index); + scatter_min_max_cpu_forward_kernel( + x.data(), + index.data(), + return_shape, + x_dims, + reduce, + index_info, + x.numel(), + dim, + post_process, + out.data(), + arg_out.data()); + break; + } + case paddle::DataType::INT64: + { + auto index_info = getTensorInfo(index); + scatter_min_max_cpu_forward_kernel( + x.data(), + index.data(), + return_shape, + x_dims, + reduce, + index_info, + x.numel(), + dim, + post_process, + out.data(), + arg_out.data()); + break; + } + default: + PD_THROW( + "function scatter_min_max_cpu_forward_kernel is not implemented for the index data type `", + phi::DataTypeToString(index.dtype()), "`"); + } + })); + + return {out, arg_out}; +} + + +std::vector scatter_min_max_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + const std::vector& return_shape, + const std::string& reduce, + int64_t dim); + +std::vector ScatterMinMaxForward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + std::vector return_shape, + std::string reduce, + int64_t dim) { + if (x.is_cpu()) { + return scatter_min_max_cpu_forward(x, index, init, return_shape, reduce, dim); + } else if (x.is_gpu()) { + return scatter_min_max_cuda_forward(x, index, init, return_shape, reduce, dim); + } else { + PD_THROW("Unsupported device type for forward function of custom scatter_min_max operator."); + } +} + +std::vector ScatterMinMaxBackward(const paddle::Tensor& x, + const paddle::Tensor& arg_out, + const paddle::Tensor& grad_out, + int64_t dim) { + if (!x.is_cpu() && !x.is_gpu() ) { + PD_THROW("Unsupported device type for backward function of custom scatter_min_max 
operator."); + } + auto x_shape = x.shape(); + x_shape[dim] += 1; + auto grad_x = paddle::zeros(x_shape, x.dtype(), x.place()); + paddle::experimental::put_along_axis_(grad_x, arg_out, grad_out, dim); + grad_x = paddle::experimental::slice(grad_x, {dim}, {0}, {x_shape[dim] - 1}, {1}, {}); + return {grad_x}; +} + +std::vector> ScatterMinMaxFWInferShape(const std::vector& x_shape, + const std::vector& index_shape, + const paddle::optional>& init_shape, + std::vector return_shape, + std::string reduce, + int64_t dim) { + return {return_shape, return_shape}; +} + +std::vector ScatterMinMaxFWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& index_dtype, + const paddle::optional& init_dtype) { + return {x_dtype, index_dtype}; +} + +std::vector> ScatterMinMaxBWInferShape(const std::vector& x_shape, + const std::vector& arg_out_shape, + const std::vector& grad_out_shape, + int64_t dim) { + return {x_shape}; +} + +std::vector ScatterMinMaxBWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& arg_out_dtype, + const paddle::DataType& grad_out_dtype) { + return {x_dtype}; +} + +PD_BUILD_OP(custom_scatter_min_max) + .Inputs({"X", "Index", paddle::Optional("Init")}) + .Outputs({"Out", "ArgOut"}) + .Attrs({"return_shape: std::vector", + "reduce: std::string", + "dim: int64_t"}) + .SetKernelFn(PD_KERNEL(ScatterMinMaxForward)) + .SetInferShapeFn(PD_INFER_SHAPE(ScatterMinMaxFWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(ScatterMinMaxFWInferDtype)); + +PD_BUILD_GRAD_OP(custom_scatter_min_max) + .Inputs({"X", "ArgOut", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .Attrs({"dim: int64_t"}) + .SetKernelFn(PD_KERNEL(ScatterMinMaxBackward)) + .SetInferShapeFn(PD_INFER_SHAPE(ScatterMinMaxBWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(ScatterMinMaxBWInferDtype)); diff --git a/jointContribution/paddle_scatter/csrc/scatter_min_max.cu b/jointContribution/paddle_scatter/csrc/scatter_min_max.cu new file mode 100644 index 000000000..5fccb6f6a --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/scatter_min_max.cu @@ -0,0 +1,206 @@ +#include "paddle/extension.h" + +#include +#include +#include + +#include "atomics.cuh" +#include "index_info.cuh" +#include "utils.cuh" + +#define THREADS 256 +#define BLOCKS(N) (N + THREADS - 1) / THREADS + + +enum ReductionType { MIN, MAX }; + +const std::map reduce2REDUCE = { + {"min", MIN}, {"max", MAX} +}; + +template +__global__ void scatter_min_max_cuda_forward_kernel(const data_t* x_data, + const TensorInfo index_info, + ReductionType reduce_type, + int numel, + int E, + int K, + int N, + mp_t* out_data) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + int b = tid / (E * K); + int k = tid % K; + + if (tid < numel) { + int64_t idx = index_info.data[IndexToOffset::get(tid, index_info)]; + + switch(reduce_type) { + case MIN: + { + atomMin(out_data + b * N * K + idx * K + k, + static_cast(x_data[tid])); + break; + } + case MAX: + { + atomMax(out_data + b * N * K + idx * K + k, + static_cast(x_data[tid])); + break; + } + } + + } +} + +template +__global__ void scatter_arg_min_max_cuda_forward_kernel(const data_t* x_data, + const TensorInfo index_info, + int numel, + int E, + int K, + int N, + mp_t* out_data, + index_t *arg_out_data) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + int b = tid / (E * K); + int e = (tid / K) % E; + int k = tid % K; + + if (tid < numel) { + int64_t idx = index_info.data[IndexToOffset::get(tid, index_info)]; + + if (static_cast(x_data[tid]) == out_data[b * N * K + idx * K + k]) { + 
arg_out_data[b * N * K + idx * K + k] = e; + } + } +} + +template +__global__ void post_process_kernel(data_t init_val, + int numel, + data_t* out_data) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (tid < numel) { + if (out_data[tid] == init_val) + out_data[tid] = static_cast(0.0); + } +} + +std::vector scatter_min_max_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + const std::vector& return_shape, + const std::string& reduce, + int64_t dim) { + CHECK_CUDA(index); + if (init) + CHECK_CUDA(init.get()); + + auto x_dims = x.shape(); + auto index_dims = index.shape(); + CHECK_INPUT(x_dims.size() == index_dims.size()); + for (auto i = 0; i < index_dims.size() - 1; ++i) + CHECK_INPUT(x_dims[i] >= index_dims[i]); + + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + paddle::Tensor arg_out; + arg_out = paddle::experimental::full_like(out, x_dims[dim], index.dtype(), index.place()); + + auto B = 1; + for (auto i = 0; i < dim; ++i) + B *= x_dims[i]; + auto E = x_dims[dim]; + auto K = x.numel() / (B * E); + auto N = return_shape[dim]; + + PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_2_TYPES( + paddle::DataType::FLOAT16, paddle::DataType::BFLOAT16, + x.dtype(), "scatter_min_max_cuda_forward_kernel", ([&] { + + using MPType = typename MPTypeTrait::Type; + paddle::Tensor out_mp; + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) { + out_mp = paddle::empty(return_shape, paddle::DataType::FLOAT32, x.place()); + } else { + out_mp = out; + } + + if (!init) { + if (reduce == "min") + paddle::experimental::fill_(out_mp, std::numeric_limits::max()); + else + paddle::experimental::fill_(out_mp, std::numeric_limits::lowest()); + } + + const data_t* x_data = x.data(); + MPType* out_data = out_mp.data(); + auto out_numel = std::accumulate(return_shape.begin(), return_shape.end(), 1.0, std::multiplies()); + + switch(index.dtype()) { + case paddle::DataType::INT32: + { + auto index_info = getTensorInfo(index); + int* arg_out_data = arg_out.data(); + scatter_min_max_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), x.numel(), + E, K, N, out_data); + + scatter_arg_min_max_cuda_forward_kernel + <<>>( + x_data, index_info, x.numel(), + E, K, N, out_data, arg_out_data); + break; + } + case paddle::DataType::INT64: + { + auto index_info = getTensorInfo(index); + int64_t* arg_out_data = arg_out.data(); + scatter_min_max_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), x.numel(), + E, K, N, out_data); + + scatter_arg_min_max_cuda_forward_kernel + <<>>( + x_data, index_info, x.numel(), + E, K, N, out_data, arg_out_data); + break; + } + default: + PD_THROW( + "function scatter_min_max_cuda_forward_kernel is not implemented for the index data type `", + phi::DataTypeToString(index.dtype()), "`"); + } + + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) { + out = paddle::experimental::cast(out_mp, x.dtype()); + } + + if (!init) { + data_t init_val = static_cast((reduce == "min") ? 
std::numeric_limits::max() : std::numeric_limits::lowest()); + post_process_kernel + <<>>( + init_val, out_numel, out.data() + ); + } + + })); + + return {out, arg_out}; +} diff --git a/jointContribution/paddle_scatter/csrc/segment_coo_min_max.cc b/jointContribution/paddle_scatter/csrc/segment_coo_min_max.cc new file mode 100644 index 000000000..33fadfcc8 --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/segment_coo_min_max.cc @@ -0,0 +1,267 @@ +#include "paddle/extension.h" + +#include +#include +#include + +#include "index_info.h" +#include "utils.h" + + +template +void segment_coo_min_max_cpu_forward_kernel(const data_t* x_data, + const index_t* index_data, + const std::vector& return_shape, + const std::vector& x_dims, + const std::string& reduce, + const TensorInfo& index_info, + int64_t x_numel, + int64_t dim, + bool post_process, + data_t* out_data, + index_t* arg_out_data) { + using MPType = typename MPTypeTrait::Type; + auto B = 1; + for (auto i = 0; i < dim; ++i) + B *= x_dims[i]; + auto E = x_dims[dim]; + auto K = x_numel / (B * E); + auto N = return_shape[dim]; + + auto stride = index_info.strides[index_info.dims - 1]; + std::vector args(K); + std::vector vals(K); + + int64_t idx, next_idx, row_start; + for (auto b = 0; b < B; b++) { + auto offset = IndexToOffset::get(b * E, index_info); + idx = index_info.data[offset]; + + for (auto k = 0; k < K; k++) + vals[k] = static_cast(out_data[b * N * K + k]); + + row_start = 0; + for (auto e = 0; e < E; e++) { + // update + for (auto k = 0; k < K; k++) { + auto cmp = static_cast(x_data[b * E * K + e * K + k]); + if ((reduce == "min" && cmp < vals[k]) || + (reduce == "max" && cmp > vals[k])) { + vals[k] = cmp; + args[k] = e; + } + } + //write + if (e == E - 1) { + for (auto k = 0; k < K; k++) { + auto idx_k = b * N * K + idx * K + k; + if (E - row_start > 0) { + out_data[idx_k] = static_cast(vals[k]); + arg_out_data[idx_k] = args[k]; + } else { + out_data[idx_k] = static_cast(0); + } + } + } else { + next_idx = index_info.data[offset + (e + 1) * stride]; + assert(idx <= next_idx); + + if (idx != next_idx) { + //write + for (auto k = 0; k < K; k++) { + auto idx_k = b * N * K + idx * K + k; + if (e + 1 - row_start > 0) { + out_data[idx_k] = static_cast(vals[k]); + arg_out_data[idx_k] = args[k]; + } else { + out_data[idx_k] = static_cast(0); + } + vals[k] = static_cast(out_data[b * N * K + next_idx * K + k]); + } + row_start = e + 1; + } + + idx = next_idx; + } + } + } + + if (post_process) { + auto out_numel = std::accumulate(return_shape.begin(), return_shape.end(), 1.0, std::multiplies()); + data_t init_val = static_cast((reduce == "min") ? 
std::numeric_limits::max() : std::numeric_limits::lowest()); + for (auto i = 0; i < out_numel; ++i) { + if (out_data[i] == init_val) + out_data[i] = static_cast(0.0); + } + } +} + +std::vector segment_coo_min_max_cpu_forward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + const std::vector& return_shape, + const std::string& reduce) { + CHECK_CPU(index); + if (init) + CHECK_CPU(init.get()); + + auto x_dims = x.shape(); + auto index_dims = index.shape(); + CHECK_INPUT(x_dims.size() >= index_dims.size()); + for (auto i = 0; i < index_dims.size() - 1; ++i) + CHECK_INPUT(x_dims[i] >= index_dims[i]); + + auto dim = index_dims.size() - 1; + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + paddle::Tensor arg_out; + arg_out = paddle::experimental::full_like(out, x_dims[dim], index.dtype(), index.place()); + + PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES( + x.dtype(), "segment_coo_min_max_cpu_forward_kernel", ([&] { + + using MPType = typename MPTypeTrait::Type; + if (!init) { + if (reduce == "min") + paddle::experimental::fill_(out, static_cast(std::numeric_limits::max())); + else + paddle::experimental::fill_(out, static_cast(std::numeric_limits::lowest())); + } + + bool post_process = (!init) ? true : false; + switch(index.dtype()) { + case paddle::DataType::INT32: + { + auto index_info = getTensorInfo(index); + segment_coo_min_max_cpu_forward_kernel( + x.data(), + index.data(), + return_shape, + x_dims, + reduce, + index_info, + x.numel(), + dim, + post_process, + out.data(), + arg_out.data()); + break; + } + case paddle::DataType::INT64: + { + auto index_info = getTensorInfo(index); + segment_coo_min_max_cpu_forward_kernel( + x.data(), + index.data(), + return_shape, + x_dims, + reduce, + index_info, + x.numel(), + dim, + post_process, + out.data(), + arg_out.data()); + break; + } + default: + PD_THROW( + "function segment_coo_min_max_cpu_forward_kernel is not implemented for the index data type `", + phi::DataTypeToString(index.dtype()), "`"); + } + })); + + return {out, arg_out}; +} + + +std::vector segment_coo_min_max_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + std::vector return_shape, + std::string reduce); + +std::vector SegmentCooMinMaxForward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + std::vector return_shape, + std::string reduce) { + if (x.is_cpu()) { + return segment_coo_min_max_cpu_forward(x, index, init, return_shape, reduce); + } else if (x.is_gpu()) { + return segment_coo_min_max_cuda_forward(x, index, init, return_shape, reduce); + } else { + PD_THROW("Unsupported device type for forward function of custom segment_coo_min_max operator."); + } +} + +std::vector SegmentCooMinMaxBackward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::Tensor& arg_out, + const paddle::Tensor& grad_out) { + if (!x.is_cpu() && !x.is_gpu() ) { + PD_THROW("Unsupported device type for backward function of custom segment_coo_min_max operator."); + } + int64_t dim = index.shape().size() - 1; + auto x_shape = x.shape(); + x_shape[dim] += 1; + auto grad_x = paddle::zeros(x_shape, x.dtype(), x.place()); + paddle::experimental::put_along_axis_(grad_x, arg_out, 
grad_out, dim); + grad_x = paddle::experimental::slice(grad_x, {dim}, {0}, {x_shape[dim] - 1}, {1}, {}); + return {grad_x}; +} + +std::vector> SegmentCooMinMaxFWInferShape(const std::vector& x_shape, + const std::vector& index_shape, + const paddle::optional>& init_shape, + std::vector return_shape, + std::string reduce) { + return {return_shape, return_shape}; +} + +std::vector SegmentCooMinMaxFWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& index_dtype, + const paddle::optional& init_dtype) { + return {x_dtype, index_dtype}; +} + +std::vector> SegmentCooMinMaxBWInferShape(const std::vector& x_shape, + const std::vector& index_shape, + const std::vector& arg_out_shape, + const std::vector& grad_out_shape) { + return {x_shape}; +} + +std::vector SegmentCooMinMaxBWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& index_dtype, + const paddle::DataType& arg_out_dtype, + const paddle::DataType& grad_out_dtype) { + return {x_dtype}; +} + +PD_BUILD_OP(custom_segment_coo_min_max) + .Inputs({"X", "Index", paddle::Optional("Init")}) + .Outputs({"Out", "ArgOut"}) + .Attrs({"return_shape: std::vector", + "reduce: std::string"}) + .SetKernelFn(PD_KERNEL(SegmentCooMinMaxForward)) + .SetInferShapeFn(PD_INFER_SHAPE(SegmentCooMinMaxFWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SegmentCooMinMaxFWInferDtype)); + +PD_BUILD_GRAD_OP(custom_segment_coo_min_max) + .Inputs({"X", "Index", "ArgOut", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(SegmentCooMinMaxBackward)) + .SetInferShapeFn(PD_INFER_SHAPE(SegmentCooMinMaxBWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SegmentCooMinMaxBWInferDtype)); diff --git a/jointContribution/paddle_scatter/csrc/segment_coo_min_max.cu b/jointContribution/paddle_scatter/csrc/segment_coo_min_max.cu new file mode 100644 index 000000000..576152f28 --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/segment_coo_min_max.cu @@ -0,0 +1,375 @@ +#include "paddle/extension.h" + +#include +#include +#include + +#include "atomics.cuh" +#include "index_info.cuh" +#include "utils.cuh" + +#define THREADS 256 +#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS +#define FULL_MASK 0xffffffff + + +enum ReductionType { MIN, MAX }; + +const std::map reduce2REDUCE = { + {"min", MIN}, {"max", MAX} +}; + +template +__global__ void +segment_coo_min_max_cuda_forward_kernel(const data_t* x_data, + const TensorInfo index_info, + ReductionType reduce_type, + mp_t* out_data, + size_t E, + size_t N) { + + // Each thread processes exactly one entry. Within a warp, we perform a + // parallel reduction across equal indices, and write the intermediate + // result via atomics. + + int row_idx = blockIdx.x * blockDim.x + threadIdx.x; + int lane_idx = row_idx & (32 - 1); + int D = index_info.sizes[index_info.dims - 1]; + + if (row_idx < E) { + int offset = IndexToOffset::get(row_idx, index_info); + int64_t idx = index_info.data[offset], next_idx; + int out_idx = (row_idx / D) * N + idx; + + mp_t val = HAS_VAL ? static_cast(x_data[row_idx]) : (mp_t)1, tmp; + +#pragma unroll + for (int i = 1; i < 32; i *= 2) { + // Parallel reduction inside a single warp. 
+ tmp = SHFL_UP_SYNC(FULL_MASK, val, i); + next_idx = SHFL_UP_SYNC(FULL_MASK, idx, i); + if (lane_idx >= i && row_idx / D == (row_idx - i) / D) { + assert(idx >= next_idx); + if (idx == next_idx) { + if ((reduce_type == MIN && tmp < val) || + (reduce_type == MAX && tmp > val)) + val = tmp; + } + } + } + + next_idx = SHFL_DOWN_SYNC(FULL_MASK, idx, 1); + if (lane_idx == 32 - 1 || row_idx / D != (row_idx + 1) / D || + idx != next_idx) { + switch(reduce_type) { + case MIN: + { + atomMin(out_data + out_idx, val); + break; + } + case MAX: + atomMax(out_data + out_idx, val); + break; + } + } + } +} + +template +__global__ void segment_coo_min_max_broadcast_cuda_forward_kernel(const data_t *x_data, + const TensorInfo index_info, + ReductionType reduce_type, + mp_t *out_data, + size_t E, + size_t K, + size_t N) { + + // Each thread processes a single column and `TB` index entries. Coalesced + // read and write is performed in column-major order. The intermediate + // results are written via atomics. + + int D = index_info.sizes[index_info.dims - 1]; + int E_1 = E / D; + int E_2 = (D - 1) + TB - ((D - 1) % TB); + + int row_idx = blockIdx.x * blockDim.y + threadIdx.y; + int col_idx = blockIdx.y * blockDim.x + threadIdx.x; + + int dim_start = (row_idx * TB) / E_2; + int row_start = (row_idx * TB) % E_2; + + if (dim_start < E_1 && col_idx < K) { + + int offset = IndexToOffset::get(dim_start * D + row_start, index_info); + int idx1 = __ldg(index_info.data + offset), idx2; + + mp_t val = static_cast(x_data[K * (dim_start * D + row_start) + col_idx]); + +#pragma unroll + for (int i = 1; i < TB; i++) { + if (row_start + i >= D) + break; + + idx2 = __ldg(index_info.data + offset + + i * index_info.strides[index_info.dims - 1]); + assert(idx1 <= idx2); + if (idx1 == idx2) { + mp_t tmp = static_cast(x_data[K * (dim_start * D + row_start + i) + col_idx]); + if ((reduce_type == MIN && tmp < val) || + (reduce_type == MAX && tmp > val)) + val = tmp; + } else { + switch(reduce_type) { + case MIN: + { + atomMin(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + } + case MAX: + atomMax(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + } + val = x_data[K * (dim_start * D + row_start + i) + col_idx]; + } + + idx1 = idx2; + } + + switch(reduce_type) { + case MIN: + { + atomMin(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + } + case MAX: + atomMax(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + } + } +} + + +template +__global__ void segment_coo_arg_kernel(const data_t *x_data, + const TensorInfo index_info, + mp_t *out_data, + index_t *arg_out_data, + size_t E, + size_t N) { + + int row_idx = blockIdx.x * blockDim.x + threadIdx.x; + int D = index_info.sizes[index_info.dims - 1]; + + if (row_idx < E) { + int offset = IndexToOffset::get(row_idx, index_info); + index_t idx = index_info.data[offset]; + int out_idx = (row_idx / D) * N + idx; + + mp_t val = __ldg(out_data + out_idx); + if (static_cast(x_data[row_idx]) == val) + arg_out_data[out_idx] = row_idx % D; + } +} + +template +__global__ void segment_coo_arg_broadcast_kernel(const data_t *x_data, + const TensorInfo index_info, + mp_t *out_data, + index_t *arg_out_data, + size_t E, + size_t K, + size_t N) { + + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / K; + int col_idx = thread_idx % K; + int D = index_info.sizes[index_info.dims - 1]; + + if (row_idx < E && col_idx < K) { + int offset = IndexToOffset::get(row_idx, index_info); + int idx = __ldg(index_info.data 
+ offset); + int out_idx = ((row_idx / D) * N + idx) * K + col_idx; + + mp_t val = __ldg(out_data + out_idx); + if (static_cast(x_data[thread_idx]) == val) + arg_out_data[out_idx] = row_idx % D; + } +} + +template +__global__ void post_process_kernel(data_t init_val, + int numel, + data_t* out_data) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (tid < numel) { + if (out_data[tid] == init_val) + out_data[tid] = static_cast(0.0); + } +} + +std::vector segment_coo_min_max_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + std::vector return_shape, + std::string reduce) { + CHECK_CUDA(index); + if (init) + CHECK_CUDA(init.get()); + + auto x_dims = x.shape(); + auto index_dims = index.shape(); + CHECK_INPUT(x_dims.size() >= index_dims.size()); + for (auto i = 0; i < index_dims.size() - 1; ++i) + CHECK_INPUT(x_dims[i] >= index_dims[i]); + + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + auto dim = index_dims.size() - 1; + paddle::Tensor arg_out; + arg_out = paddle::experimental::full_like(out, x_dims[dim], index.dtype(), index.place()); + + auto E = index.numel(); + auto E_2 = index.shape()[dim]; + auto E_1 = index.numel() / E_2; + auto K = x.numel() / E; + auto N = out.shape()[dim]; + auto avg_len = (float)E_2 / (float)N; + + PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_2_TYPES( + paddle::DataType::FLOAT16, paddle::DataType::BFLOAT16, + x.dtype(), "segment_coo_min_max_cuda_forward_kernel", ([&] { + + using MPType = typename MPTypeTrait::Type; + paddle::Tensor out_mp; + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) { + out_mp = paddle::empty(return_shape, paddle::DataType::FLOAT32, x.place()); + } else { + out_mp = out; + } + + if (!init) { + if (reduce == "min") + paddle::experimental::fill_(out_mp, std::numeric_limits::max()); + else + paddle::experimental::fill_(out_mp, std::numeric_limits::lowest()); + } + + const data_t* x_data = x.data(); + MPType* out_data = out_mp.data(); + auto out_numel = std::accumulate(return_shape.begin(), return_shape.end(), 1.0, std::multiplies()); + + switch(index.dtype()) { + case paddle::DataType::INT32: + { + auto index_info = getTensorInfo(index); + int* arg_out_data = arg_out.data(); + if (K == 1) + segment_coo_min_max_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, N); + + else if (avg_len <= 8) + segment_coo_min_max_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + else if (avg_len <= 16) + segment_coo_min_max_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + else if (avg_len <= 32) + segment_coo_min_max_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + else + segment_coo_min_max_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + + if (K == 1) + segment_coo_arg_kernel + <<>>( + x_data, index_info, out_data, arg_out_data, E, N); + else + segment_coo_arg_broadcast_kernel + <<>>( + x_data, index_info, out_data, arg_out_data, E, K, N); + break; + } + case paddle::DataType::INT64: + { + auto index_info = 
getTensorInfo(index); + int64_t* arg_out_data = arg_out.data(); + if (K == 1) + segment_coo_min_max_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, N); + + else if (avg_len <= 8) + segment_coo_min_max_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + else if (avg_len <= 16) + segment_coo_min_max_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + else if (avg_len <= 32) + segment_coo_min_max_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + else + segment_coo_min_max_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + + if (K == 1) + segment_coo_arg_kernel + <<>>( + x_data, index_info, out_data, arg_out_data, E, N); + else + segment_coo_arg_broadcast_kernel + <<>>( + x_data, index_info, out_data, arg_out_data, E, K, N); + break; + } + default: + PD_THROW( + "function segment_coo_min_max_cuda_forward_kernel is not implemented for the index data type `", + phi::DataTypeToString(index.dtype()), "`"); + } + + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) { + out = paddle::experimental::cast(out_mp, x.dtype()); + } + + if (!init) { + data_t init_val = static_cast((reduce == "min") ? std::numeric_limits::max() : std::numeric_limits::lowest()); + post_process_kernel + <<>>( + init_val, out_numel, out.data() + ); + } + + })); + + return {out, arg_out}; +} diff --git a/jointContribution/paddle_scatter/csrc/segment_csr_min_max.cc b/jointContribution/paddle_scatter/csrc/segment_csr_min_max.cc new file mode 100644 index 000000000..1563e7091 --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/segment_csr_min_max.cc @@ -0,0 +1,196 @@ +#include "paddle/extension.h" + +#include +#include +#include + +#include "index_info.h" +#include "utils.h" + + +template +void segment_csr_min_max_cpu_forward_kernel(const data_t* x_data, + const std::vector& indptr_shape, + const TensorInfo& indptr_info, + const std::string& reduce, + int stride, + int64_t dim, + int N, + int K, + int E, + data_t* out_data, + index_t* arg_out_data) { + using MPType = typename MPTypeTrait::Type; + std::vector args(K); + std::vector vals(K); + index_t row_start, row_end; + for (auto n = 0; n < N; n++) { + auto offset = IndexPtrToOffset::get(n, indptr_info); + row_start = indptr_info.data[offset]; + row_end = indptr_info.data[offset + stride]; + + offset = (n / (indptr_shape[dim] - 1)) * E * K; + for (auto k = 0; k < K; k++) { + if (reduce == "min") + vals[k] = static_cast(std::numeric_limits::max()); + else + vals[k] = static_cast(std::numeric_limits::lowest()); + } + + for (auto e = row_start; e < row_end; e++) { + for (auto k = 0; k < K; k++) { + // update + auto cmp = static_cast(x_data[offset + e * K + k]); + if ((reduce == "min" && cmp < vals[k]) || + (reduce == "max" && cmp > vals[k])) { + vals[k] = cmp; + args[k] = e; + } + } + } + + for (auto k = 0; k < K; k++) { + // write + auto idx = n * K + k; + if (row_end - row_start > 0) { + out_data[idx] = static_cast(vals[k]); + arg_out_data[idx] = args[k]; + } else { + out_data[idx] = static_cast(0); + } + } + } +} + +std::vector segment_csr_min_max_cpu_forward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const std::vector& return_shape, + const std::string& reduce) { + CHECK_CPU(indptr); + + auto x_dims = x.shape(); + auto indptr_dims = indptr.shape(); + 
CHECK_INPUT(x_dims.size() >= indptr_dims.size()); + auto dim = indptr_dims.size() - 1; + + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + out = paddle::empty(return_shape, x.dtype(), x.place()); + + paddle::Tensor arg_out; + arg_out = paddle::experimental::full_like(out, x_dims[dim], indptr.dtype(), indptr.place()); + + auto N = return_shape[dim] * (indptr.numel() / indptr_dims[dim]); + auto K = out.numel() / N; + auto E = x_dims[dim]; + + PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES( + x.dtype(), "segment_csr_min_max_cpu_forward_kernel", ([&] { + + switch(indptr.dtype()) { + case paddle::DataType::INT32: + { + auto indptr_info = getTensorInfo(indptr); + int stride = indptr_info.strides[indptr_info.dims - 1]; + segment_csr_min_max_cpu_forward_kernel( + x.data(), indptr_dims, indptr_info, reduce, + stride, dim, N, K, E, out.data(), arg_out.data()); + break; + } + case paddle::DataType::INT64: + { + auto indptr_info = getTensorInfo(indptr); + int stride = indptr_info.strides[indptr_info.dims - 1]; + segment_csr_min_max_cpu_forward_kernel( + x.data(), indptr_dims, indptr_info, reduce, + stride, dim, N, K, E, out.data(), arg_out.data()); + break; + } + default: + PD_THROW( + "function segment_csr_min_max_cpu_forward_kernel is not implemented for the indptr data type `", + phi::DataTypeToString(indptr.dtype()), "`"); + } + })); + + return {out, arg_out}; +} + + +std::vector segment_csr_min_max_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const std::vector& return_shape, + const std::string& reduce); + +std::vector SegmentCsrMinMaxForward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const std::vector& return_shape, + const std::string& reduce) { + if (x.is_cpu()) { + return segment_csr_min_max_cpu_forward(x, indptr, return_shape, reduce); + } else if (x.is_gpu()) { + return segment_csr_min_max_cuda_forward(x, indptr, return_shape, reduce); + } else { + PD_THROW("Unsupported device type for forward function of custom segment_csr_min_max operator."); + } +} + +std::vector SegmentCsrMinMaxBackward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::Tensor& arg_out, + const paddle::Tensor& grad_out) { + if (!x.is_cpu() && !x.is_gpu() ) { + PD_THROW("Unsupported device type for backward function of custom segment_csr_min_max operator."); + } + int64_t dim = indptr.shape().size() - 1; + auto x_shape = x.shape(); + x_shape[dim] += 1; + auto grad_x = paddle::zeros(x_shape, x.dtype(), x.place()); + paddle::experimental::put_along_axis_(grad_x, arg_out, grad_out, dim); + grad_x = paddle::experimental::slice(grad_x, {dim}, {0}, {x_shape[dim] - 1}, {1}, {}); + return {grad_x}; +} + +std::vector> SegmentCsrMinMaxFWInferShape(const std::vector& x_shape, + const std::vector& indptr_shape, + std::vector return_shape, + std::string reduce) { + return {return_shape, return_shape}; +} + +std::vector SegmentCsrMinMaxFWInferDtype(paddle::DataType x_dtype, + paddle::DataType indptr_dtype) { + return {x_dtype, indptr_dtype}; +} + +std::vector> SegmentCsrMinMaxBWInferShape(const std::vector& x_shape, + const std::vector& indptr_shape, + const std::vector& arg_out_shape, + const std::vector& grad_out_shape) { + return {x_shape}; +} + +std::vector SegmentCsrMinMaxBWInferDtype(paddle::DataType x_dtype, + paddle::DataType indptr_dtype, + paddle::DataType arg_out_dtype, + paddle::DataType grad_out_dtype) { + return {x_dtype}; +} + +PD_BUILD_OP(custom_segment_csr_min_max) + .Inputs({"X", "Indptr"}) + 
.Outputs({"Out", "ArgOut"}) + .Attrs({"return_shape: std::vector", + "reduce: std::string"}) + .SetKernelFn(PD_KERNEL(SegmentCsrMinMaxForward)) + .SetInferShapeFn(PD_INFER_SHAPE(SegmentCsrMinMaxFWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SegmentCsrMinMaxFWInferDtype)); + +PD_BUILD_GRAD_OP(custom_segment_csr_min_max) + .Inputs({"X", "Indptr", "ArgOut", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(SegmentCsrMinMaxBackward)) + .SetInferShapeFn(PD_INFER_SHAPE(SegmentCsrMinMaxBWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SegmentCsrMinMaxBWInferDtype)); diff --git a/jointContribution/paddle_scatter/csrc/segment_csr_min_max.cu b/jointContribution/paddle_scatter/csrc/segment_csr_min_max.cu new file mode 100644 index 000000000..7e3cffae1 --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/segment_csr_min_max.cu @@ -0,0 +1,207 @@ +#include "paddle/extension.h" + +#include +#include +#include + +#include "atomics.cuh" +#include "index_info.cuh" +#include "utils.cuh" + +#define THREADS 256 +#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS +#define FULL_MASK 0xffffffff + + +enum ReductionType { MIN, MAX }; + +const std::map reduce2REDUCE = { + {"min", MIN}, {"max", MAX} +}; + +template +__global__ void segment_csr_min_max_kernel(const data_t *x_data, + const TensorInfo indptr_info, + ReductionType reduce_type, + data_t *out_data, + index_t *arg_out_data, + size_t N, + size_t E) { + + // Each warp processes exactly `32/TB` rows and aggregates all row values + // via a parallel reduction. + + using MPType = typename MPTypeTrait::Type; + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / TB; + int lane_idx = thread_idx & (TB - 1); + + if (row_idx < N) { + int offset = IndexPtrToOffset::get(row_idx, indptr_info); + index_t row_start = __ldg(indptr_info.data + offset); + index_t row_end = __ldg(indptr_info.data + offset + + indptr_info.strides[indptr_info.dims - 1]); + + data_t val = (reduce_type == MIN) ? static_cast(std::numeric_limits::max()) : static_cast(std::numeric_limits::lowest()); + index_t arg, arg_tmp; + + offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E; + for (index_t x_idx = row_start + lane_idx; x_idx < row_end; + x_idx += TB) { + // update + auto cmp = x_data[offset + x_idx]; + if ((reduce_type == MIN && cmp < val) || + (reduce_type == MAX && cmp > val)) { + val = cmp; + arg = x_idx; + } + } + +#pragma unroll + for (int i = TB / 2; i > 0; i /= 2) { + // Parallel reduction inside a single warp. + arg_tmp = SHFL_DOWN_SYNC(FULL_MASK, arg, i); + // update + MPType cmp = SHFL_DOWN_SYNC(FULL_MASK, static_cast(val), i); + if ((reduce_type == MIN && cmp < static_cast(val)) || + (reduce_type == MAX && cmp > static_cast(val))) { + val = static_cast(cmp); + arg = arg_tmp; + } + } + + if (lane_idx == 0) { + // write + if (row_end - row_start > 0) { + out_data[row_idx] = val; + arg_out_data[row_idx] = arg; + } else { + out_data[row_idx] = static_cast(0); + } + } + } +} + +template +__global__ void segment_csr_broadcast_min_max_kernel(const data_t *x_data, + const TensorInfo indptr_info, + ReductionType reduce_type, + data_t *out_data, + index_t *arg_out_data, + size_t N, + size_t K, + size_t E) { + + // Each thread processes exactly one row. It turned out that is more + // efficient than using shared memory due to avoiding synchronization + // barriers. 
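+  // Here thread_idx is decomposed into (row_idx, lane_idx): row_idx picks the
+  // CSR segment and lane_idx picks one of the K broadcast columns, so every
+  // thread walks its segment range [row_start, row_end) sequentially.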
+ + using MPType = typename MPTypeTrait::Type; + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / K; + int lane_idx = thread_idx % K; + + if (thread_idx < N * K) { + int offset = IndexPtrToOffset::get(row_idx, indptr_info); + index_t row_start = __ldg(indptr_info.data + offset); + index_t row_end = __ldg(indptr_info.data + offset + + indptr_info.strides[indptr_info.dims - 1]); + + data_t val = (reduce_type == MIN) ? static_cast(std::numeric_limits::max()) : static_cast(std::numeric_limits::lowest()); + index_t arg; + + offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K; + for (index_t x_idx = row_start; x_idx < row_end; x_idx++) { + // update + auto cmp = x_data[offset + K * x_idx + lane_idx]; + if ((reduce_type == MIN && cmp < val) || + (reduce_type == MAX && cmp > val)) { + val = cmp; + arg = x_idx; + } + } + + // write + if (row_end - row_start > 0) { + out_data[thread_idx] = val; + arg_out_data[thread_idx] = arg; + } else { + out_data[thread_idx] = static_cast(0); + } + } +} + +std::vector segment_csr_min_max_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const std::vector& return_shape, + const std::string& reduce) { + CHECK_CUDA(indptr); + + auto x_dims = x.shape(); + auto indptr_dims = indptr.shape(); + CHECK_INPUT(x_dims.size() >= indptr_dims.size()); + auto dim = indptr_dims.size() - 1; + + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + out = paddle::empty(return_shape, x.dtype(), x.place()); + + paddle::Tensor arg_out; + arg_out = paddle::experimental::full_like(out, x_dims[dim], indptr.dtype(), indptr.place()); + + auto N = return_shape[dim] * (indptr.numel() / indptr_dims[dim]); + auto K = out.numel() / N; + auto E = x_dims[dim]; + + PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_2_TYPES( + paddle::DataType::FLOAT16, paddle::DataType::BFLOAT16, + x.dtype(), "segment_csr_min_max_cuda_forward_kernel", ([&] { + + const data_t* x_data = x.data(); + data_t* out_data = out.data(); + switch(indptr.dtype()) { + case paddle::DataType::INT32: + { + auto indptr_info = getTensorInfo(indptr); + int* arg_out_data = arg_out.data(); + if (K == 1) { + segment_csr_min_max_kernel + <<>>( + x_data, indptr_info, reduce2REDUCE.at(reduce), + out_data, arg_out_data, N, E); + } else { + segment_csr_broadcast_min_max_kernel + <<>>( + x_data, indptr_info, reduce2REDUCE.at(reduce), + out_data, arg_out_data, N, K, E); + } + break; + } + case paddle::DataType::INT64: + { + auto indptr_info = getTensorInfo(indptr); + int64_t* arg_out_data = arg_out.data(); + if (K == 1) { + segment_csr_min_max_kernel + <<>>( + x_data, indptr_info, reduce2REDUCE.at(reduce), + out_data, arg_out_data, N, E); + } else { + segment_csr_broadcast_min_max_kernel + <<>>( + x_data, indptr_info, reduce2REDUCE.at(reduce), + out_data, arg_out_data, N, K, E); + } + break; + } + default: + PD_THROW( + "function segment_csr_min_max_cuda_forward_kernel is not implemented for the indptr data type `", + phi::DataTypeToString(indptr.dtype()), "`"); + } + })); + + return {out, arg_out}; +} diff --git a/jointContribution/paddle_scatter/csrc/utils.cuh b/jointContribution/paddle_scatter/csrc/utils.cuh new file mode 100644 index 000000000..d5759541b --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/utils.cuh @@ -0,0 +1,88 @@ +#pragma once + +#include "paddle/extension.h" + + +#define CHECK_CUDA(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.") +#define CHECK_INPUT(x) PD_CHECK(x, "Input mismatch") + +///////// 
Basic Marco /////////// + +#define PD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, HINT, ...) \ + case enum_type: { \ + using HINT = type; \ + __VA_ARGS__(); \ + break; \ + } + +#define PD_PRIVATE_CASE_TYPE(NAME, enum_type, type, ...) \ + PD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, data_t, __VA_ARGS__) + +///////// Floating and Integral Dispatch Marco /////////// + +#define PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_2_TYPES( \ + SPECIFIED_TYPE1, SPECIFIED_TYPE2, TYPE, NAME, ...) \ + PD_VISIT_FLOATING_AND_INTEGRAL_AND_2_TYPES( \ + SPECIFIED_TYPE1, SPECIFIED_TYPE2, TYPE, NAME, __VA_ARGS__) + +///////// Floating and Integral Dispatch Marco /////////// + +#define PD_VISIT_FLOATING_AND_INTEGRAL_AND_2_TYPES( \ + SPECIFIED_TYPE1, SPECIFIED_TYPE2, TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE(NAME, \ + SPECIFIED_TYPE1, \ + ::phi::DataTypeToCppType::type, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + SPECIFIED_TYPE2, \ + ::phi::DataTypeToCppType::type, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::UINT8, uint8_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::INT16, int16_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::INT64, int64_t, __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `", \ + __dtype__, \ + "`"); \ + } \ + }() + +template +class MPTypeTrait { + public: + using Type = T; +}; + +template <> +class MPTypeTrait { + public: + using Type = float; +}; + +template <> +class MPTypeTrait { + public: + using Type = float; +}; + +#ifdef USE_ROCM +#define SHFL_UP_SYNC(mask, var, delta) __shfl_up(var, delta) +#define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down(var, delta) +#else +#define SHFL_UP_SYNC __shfl_up_sync +#define SHFL_DOWN_SYNC __shfl_down_sync +#endif \ No newline at end of file diff --git a/jointContribution/paddle_scatter/csrc/utils.h b/jointContribution/paddle_scatter/csrc/utils.h new file mode 100644 index 000000000..851a753e9 --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/utils.h @@ -0,0 +1,26 @@ +#pragma once + +#include "paddle/extension.h" + + +#define CHECK_CPU(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.") +#define CHECK_INPUT(x) PD_CHECK(x, "Input mismatch") + + +template +class MPTypeTrait { + public: + using Type = T; +}; + +template <> +class MPTypeTrait { + public: + using Type = float; +}; + +template <> +class MPTypeTrait { + public: + using Type = float; +}; diff --git a/jointContribution/paddle_scatter/scatter.py b/jointContribution/paddle_scatter/scatter.py new file mode 100644 index 000000000..c513ecf52 --- /dev/null +++ b/jointContribution/paddle_scatter/scatter.py @@ -0,0 +1,398 @@ +from typing import Optional +from typing import Tuple + +import paddle +from paddle import assign +from paddle import divide +from paddle import floor_divide +from paddle import full +from paddle import full_like +from paddle import ones +from paddle import put_along_axis +from paddle import where +from paddle import zeros +from paddle_scatter_min_max_ops import custom_scatter_min_max + +from .utils import broadcast + + +def scatter_sum( + 
src: paddle.Tensor, + index: paddle.Tensor, + dim: int = -1, + out: Optional[paddle.Tensor] = None, + dim_size: Optional[int] = None, +) -> paddle.Tensor: + r"""Reduces all values from the `src` tensor into `out` at the + indices specified in the `index` tensor along a given axis`dim`, + the reduction method is sum. + + Args: + src (paddle.Tensor): The source tensor. + index (paddle.Tensor): The indices of elements to scatter. The dimension + of index should either be 1-D or :math:`i+1`-D. See Notes for more + details. + dim (int, optional): The axis along which to index. Default is -1. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + dim_size (int|None, optional): If `out` is not given, automatically create output + with size `dim_size` at dimension `dim`. If `dim_size` is not given, + a minimal sized output tensor according to `index.max() + 1` is returned. + Default is None. + + Returns: + paddle.Tensor, the reduced tensor by sum reduction method. + """ + index = broadcast(index, src, dim) + if out is None: + size = src.shape + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + arr = zeros(size, dtype=src.dtype) + if src.numel() == 0: + return arr + return put_along_axis(arr, indices=index, values=src, axis=dim, reduce="add") + else: + if src.numel() == 0: + return out + result = put_along_axis(out, indices=index, values=src, axis=dim, reduce="add") + assign(result, out) + return out + + +def scatter_add( + src: paddle.Tensor, + index: paddle.Tensor, + dim: int = -1, + out: Optional[paddle.Tensor] = None, + dim_size: Optional[int] = None, +) -> paddle.Tensor: + r"""Reduces all values from the `src` tensor into `out` at the + indices specified in the `index` tensor along a given axis`dim`, + the reduction method is sum. + + Args: + src (paddle.Tensor): The source tensor. + index (paddle.Tensor): The indices of elements to scatter. The dimension + of index should either be 1-D or :math:`i+1`-D. See Notes for more + details. + dim (int, optional): The axis along which to index. Default is -1. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + dim_size (int|None, optional): If `out` is not given, automatically create output + with size `dim_size` at dimension `dim`. If `dim_size` is not given, + a minimal sized output tensor according to `index.max() + 1` is returned. + Default is None. + + Returns: + paddle.Tensor, the reduced tensor by sum reduction method. + """ + return scatter_sum(src, index, dim, out, dim_size) + + +def scatter_mul( + src: paddle.Tensor, + index: paddle.Tensor, + dim: int = -1, + out: Optional[paddle.Tensor] = None, + dim_size: Optional[int] = None, +) -> paddle.Tensor: + r"""Reduces all values from the `src` tensor into `out` at the + indices specified in the `index` tensor along a given axis`dim`, + the reduction method is multiply. + + Args: + src (paddle.Tensor): The source tensor. + index (paddle.Tensor): The indices of elements to scatter. The dimension + of index should either be 1-D or :math:`i+1`-D. See Notes for more + details. + dim (int, optional): The axis along which to index. Default is -1. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + dim_size (int|None, optional): If `out` is not given, automatically create output + with size `dim_size` at dimension `dim`. If `dim_size` is not given, + a minimal sized output tensor according to `index.max() + 1` is returned. 
+ Default is None. + + Returns: + paddle.Tensor, the reduced tensor by multiply reduction method. + """ + index = broadcast(index, src, dim) + if out is None: + size = src.shape + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + arr = ones(size, dtype=src.dtype) + if src.numel() == 0: + return arr + return put_along_axis(arr, indices=index, values=src, axis=dim, reduce="mul") + else: + if src.numel() == 0: + return out + result = put_along_axis(out, indices=index, values=src, axis=dim, reduce="mul") + assign(result, out) + return out + + +def scatter_mean( + src: paddle.Tensor, + index: paddle.Tensor, + dim: int = -1, + out: Optional[paddle.Tensor] = None, + dim_size: Optional[int] = None, +) -> paddle.Tensor: + r"""Reduces all values from the `src` tensor into `out` at the + indices specified in the `index` tensor along a given axis`dim`, + the reduction method is mean. (If dtype of `src` is int, output is still int.) + + Args: + src (paddle.Tensor): The source tensor. + index (paddle.Tensor): The indices of elements to scatter. The dimension + of index should either be 1-D or :math:`i+1`-D. See Notes for more + details. + dim (int, optional): The axis along which to index. Default is -1. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + dim_size (int|None, optional): If `out` is not given, automatically create output + with size `dim_size` at dimension `dim`. If `dim_size` is not given, + a minimal sized output tensor according to `index.max() + 1` is returned. + Default is None. + + Returns: + paddle.Tensor, the reduced tensor by mean reduction method. + """ + sums = scatter_sum(src, index, dim, out, dim_size) + dim_size = sums.shape[dim] + + index_dim = dim + if index_dim < 0: + index_dim = index_dim + src.dim() + if index.dim() <= index_dim: + index_dim = index.dim() - 1 + + ones_tensor = ones(index.shape, dtype=src.dtype) + tmp = scatter_sum(ones_tensor, index, index_dim, None, dim_size) + count = where(tmp < 1, full_like(tmp, 1), tmp, name="where") + count = broadcast(count, sums, dim) + if sums.is_floating_point(): + result = divide(sums, count) + else: + result = floor_divide(sums, count) + if out is None: + return result + else: + assign(result, out) + return out + + +def scatter_min( + src: paddle.Tensor, + index: paddle.Tensor, + dim: int = -1, + out: Optional[paddle.Tensor] = None, + dim_size: Optional[int] = None, +) -> Tuple[paddle.Tensor, paddle.Tensor]: + r"""Reduces all values from the `src` tensor into `out` at the + indices specified in the `index` tensor along a given axis`dim`, + the reduction method is min. + + Args: + src (paddle.Tensor): The source tensor. + index (paddle.Tensor): The indices of elements to scatter. The dimension + of index should either be 1-D or :math:`i+1`-D. See Notes for more + details. + dim (int, optional): The axis along which to index. Default is -1. + out (paddle.Tensor|None, optional): The destination tensor of min result. + Default is None. + dim_size (int|None, optional): If `out` is not given, automatically create output + with size `dim_size` at dimension `dim`. If `dim_size` is not given, + a minimal sized output tensor according to `index.max() + 1` is returned. + Default is None. + + Returns: + Tuple[paddle.Tensor, paddle.Tensor], the reduced min tensor and arg_min tensor. 
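+    Examples:
+        A minimal usage sketch (shape check only):
+
+        >>> from paddle_scatter import scatter_min
+
+        >>> src = paddle.to_tensor([2.0, 0.0, 1.0, 4.0, 3.0])
+        >>> index = paddle.to_tensor([0, 0, 1, 1, 1])
+
+        >>> out, argmin = scatter_min(src, index)
+        >>> print(out.shape, argmin.shape)
+        [2] [2]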
+ """ + if dim < 0: + dim = dim + index.dim() + index = broadcast(index, src, dim) + size = src.shape + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + + if out is None: + if src.numel() == 0: + return ( + zeros(size, dtype=src.dtype), + full(size, src.shape[dim], index.dtype), + ) + return custom_scatter_min_max(src, index, None, size, "min", dim) + else: + if src.numel() == 0: + return (out, full(size, src.shape[dim], index.dtype)) + for i in range(len(size)): + if i != dim: + assert size[i] == out.shape[i] + result, arg_result = custom_scatter_min_max( + src, index, out, out.shape, "min", dim + ) + assign(result, out) + return out, arg_result + + +def scatter_max( + src: paddle.Tensor, + index: paddle.Tensor, + dim: int = -1, + out: Optional[paddle.Tensor] = None, + dim_size: Optional[int] = None, +) -> Tuple[paddle.Tensor, paddle.Tensor]: + r"""Reduces all values from the `src` tensor into `out` at the + indices specified in the `index` tensor along a given axis`dim`, + the reduction method is max. + + Args: + src (paddle.Tensor): The source tensor. + index (paddle.Tensor): The indices of elements to scatter. The dimension + of index should either be 1-D or :math:`i+1`-D. See Notes for more + details. + dim (int, optional): The axis along which to index. Default is -1. + out (paddle.Tensor|None, optional): The destination tensor of max result. + Default is None. + dim_size (int|None, optional): If `out` is not given, automatically create output + with size `dim_size` at dimension `dim`. If `dim_size` is not given, + a minimal sized output tensor according to `index.max() + 1` is returned. + Default is None. + + Returns: + Tuple[paddle.Tensor, paddle.Tensor], the reduced max tensor and arg_max tensor. + """ + if dim < 0: + dim = dim + index.dim() + index = broadcast(index, src, dim) + size = src.shape + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + + if out is None: + if src.numel() == 0: + return ( + zeros(size, dtype=src.dtype), + full(size, src.shape[dim], index.dtype), + ) + return custom_scatter_min_max(src, index, None, size, "max", dim) + else: + if src.numel() == 0: + return (out, full(size, src.shape[dim], index.dtype)) + for i in range(len(size)): + if i != dim: + assert size[i] == out.shape[i] + result, arg_result = custom_scatter_min_max( + src, index, out, out.shape, "max", dim + ) + assign(result, out) + return out, arg_result + + +def scatter( + src: paddle.Tensor, + index: paddle.Tensor, + dim: int = -1, + out: Optional[paddle.Tensor] = None, + dim_size: Optional[int] = None, + reduce: str = "sum", +) -> paddle.Tensor: + r"""Reduces all values from the `src` tensor into `out` at the + indices specified in the `index` tensor along a given axis`dim`. + + For each value in `src`, its output index is specified by its index + in `src` for dimensions outside of `dim` and by the corresponding + value in `index` for dimension `dim`. The applied reduction is defined + via the `reduce` argument. + + Formally, if `src` and `index` are :math:`n`-dimensional + tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` + and `dim` = `i`, then `out` must be an :math:`n`-dimensional + tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. + Moreover, the values of `index` must be between :math:`0` and + :math:`y - 1`, although no specific ordering of indices is required. 
+ The `index` tensor supports broadcasting in case its dimensions do + not match with `src`. + + For one-dimensional tensors with `reduce="sum"`, the operation + computes + + $$ + \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j + $$ + + where :math:`\sum_j` is over :math:`j` such that + :math:`\mathrm{index}_j = i`. + + Notes: + `index` tensor supports broadcasting, and its shape should be either + :math:`(x_i, )` or :math:`(*_0, ..., *_{i-1}, x_i)`, + where :math:`*_k (k <= i-1)` should be either :math:`1` or :math:`x_k`. + + Args: + src (paddle.Tensor): The source tensor. + index (paddle.Tensor): The indices of elements to scatter. The dimension + of index should either be 1-D or :math:`i+1`-D. See Notes for more + details. + dim (int, optional): The axis along which to index. Default is -1. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + dim_size (int|None, optional): If `out` is not given, automatically create output + with size `dim_size` at dimension `dim`. If `dim_size` is not given, + a minimal sized output tensor according to `index.max() + 1` is returned. + Default is None. + reduce (str, optional): The reduce operation supports `"sum"`, `"add"`, `"mul"`, + `"mean"`, `"min"` or `"max"`. Default is `"sum"`. + + Returns: + paddle.Tensor, the reduced tensor. + + Examples: + >>> from paddle_scatter import scatter + + >>> src = paddle.randn([10, 6, 64]) + >>> index = paddle.tensor([0, 1, 0, 1, 2, 1]) + + >>> # Broadcasting in the first and last dim + >>> out = scatter(src, index, dim=1, reduce="sum") + >>> print(out.shape) + [10, 3, 64] + + >>> # Specify `dim_size` + >>> out = scatter(src, index, dim=1, dim_size=4, reduce="sum") + >>> print(out.shape) + [10, 4, 64] + + >>> # Specify `out` + >>> out = paddle.empty([10, 3, 64]) + >>> scatter(src, index, dim=1, out=out, reduce="sum") + >>> print(out.shape) + [10, 3, 64] + """ + if reduce == "sum" or reduce == "add": + return scatter_sum(src, index, dim, out, dim_size) + if reduce == "mul": + return scatter_mul(src, index, dim, out, dim_size) + elif reduce == "mean": + return scatter_mean(src, index, dim, out, dim_size) + elif reduce == "min": + return scatter_min(src, index, dim, out, dim_size)[0] + elif reduce == "max": + return scatter_max(src, index, dim, out, dim_size)[0] + else: + raise ValueError diff --git a/jointContribution/paddle_scatter/segment_coo.py b/jointContribution/paddle_scatter/segment_coo.py new file mode 100644 index 000000000..75af63aff --- /dev/null +++ b/jointContribution/paddle_scatter/segment_coo.py @@ -0,0 +1,461 @@ +from typing import Optional +from typing import Tuple + +import paddle +from paddle import assign +from paddle import full +from paddle import slice +from paddle import take_along_axis +from paddle import to_tensor +from paddle import zeros +from paddle.geometric import segment_mean +from paddle.geometric import segment_sum +from paddle.nn.functional import pad +from paddle_scatter_min_max_ops import custom_segment_coo_min_max + +from .utils import broadcast +from .utils import transform_3d + + +def segment_sum_coo( + src: paddle.Tensor, + index: paddle.Tensor, + out: Optional[paddle.Tensor] = None, + dim_size: Optional[int] = None, +) -> paddle.Tensor: + r"""Reduces all values from the `src` tensor into `out` at the + indices specified in the `index` tensor along the last dimension of + `index`. The reduction method is sum. + + Args: + src (paddle.Tensor): The source tensor. + index (paddle.Tensor): The sorted indices of elements to segment. 
+ The number of dimensions of `index` needs to be less than or + equal to `src`. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + dim_size (int|None, optional): If `out` is not given, automatically create output + with size `dim_size` at dimension `dim`. If `dim_size` is not given, + a minimal sized output tensor according to `index.max() + 1` is returned. + Default is None. + + Returns: + paddle.Tensor, the reduced tensor by sum reduction method. + """ + index_shape = index.shape + dim = len(index_shape) - 1 + index = broadcast(index, src, dim) + if out is None: + out_size = src.shape + if dim_size is not None: + out_size[dim] = dim_size + elif index.numel() == 0: + out_size[dim] = 0 + else: + tmp = index.index_select( + index=to_tensor([index.shape[dim] - 1]), axis=dim + ).squeeze(dim) + tmp = tmp.max() if tmp.numel() > 1 else tmp + out_size[dim] = int(tmp) + 1 + if src.numel() == 0: + return zeros(out_size, dtype=src.dtype) + else: + if src.numel() == 0: + return out + out_size = out.shape + init = out.clone() + tmp = zeros(out_size, dtype=src.dtype) + src_fatten = transform_3d(src, dim) + index_flatten = transform_3d(index, dim) + out_flatten = transform_3d(tmp, dim) + for i in range(src_fatten.shape[0]): + for j in range(src_fatten.shape[-1]): + result = segment_sum(src_fatten[i, :, j], index_flatten[i, :, j]) + if out_size[dim] >= len(result): + out_flatten[i, : len(result), j] = result + else: + out_flatten[i, :, j] = result[: out_size[dim]] + res = out_flatten.reshape(out_size) + if out is None: + return res + else: + res = res + init + assign(res, out) + return out + + +def segment_add_coo( + src: paddle.Tensor, + index: paddle.Tensor, + out: Optional[paddle.Tensor] = None, + dim_size: Optional[int] = None, +) -> paddle.Tensor: + r"""Reduces all values from the `src` tensor into `out` at the + indices specified in the `index` tensor along the last dimension of + `index`. The reduction method is sum. + + Args: + src (paddle.Tensor): The source tensor. + index (paddle.Tensor): The sorted indices of elements to segment. + The number of dimensions of `index` needs to be less than or + equal to `src`. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + dim_size (int|None, optional): If `out` is not given, automatically create output + with size `dim_size` at dimension `dim`. If `dim_size` is not given, + a minimal sized output tensor according to `index.max() + 1` is returned. + Default is None. + + Returns: + paddle.Tensor, the reduced tensor by sum reduction method. + """ + return segment_sum_coo(src, index, out, dim_size) + + +def segment_mean_coo( + src: paddle.Tensor, + index: paddle.Tensor, + out: Optional[paddle.Tensor] = None, + dim_size: Optional[int] = None, +) -> paddle.Tensor: + r"""Reduces all values from the `src` tensor into `out` at the + indices specified in the `index` tensor along the last dimension of + `index`. The reduction method is mean. + + Args: + src (paddle.Tensor): The source tensor. + index (paddle.Tensor): The sorted indices of elements to segment. + The number of dimensions of `index` needs to be less than or + equal to `src`. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + dim_size (int|None, optional): If `out` is not given, automatically create output + with size `dim_size` at dimension `dim`. If `dim_size` is not given, + a minimal sized output tensor according to `index.max() + 1` is returned. + Default is None. 
+ + Returns: + paddle.Tensor, the reduced tensor by mean reduction method. + """ + index_shape = index.shape + dim = len(index_shape) - 1 + index = broadcast(index, src, dim) + if out is None: + out_size = src.shape + if dim_size is not None: + out_size[dim] = dim_size + elif index.numel() == 0: + out_size[dim] = 0 + else: + tmp = index.index_select( + index=to_tensor([index.shape[dim] - 1]), axis=dim + ).squeeze(dim) + tmp = tmp.max() if tmp.numel() > 1 else tmp + out_size[dim] = int(tmp) + 1 + out = zeros(out_size, dtype=src.dtype) + else: + out_size = out.shape + if src.numel() == 0: + return out + src_fatten = transform_3d(src, dim) + index_flatten = transform_3d(index, dim) + out_flatten = transform_3d(out, dim) + for i in range(src_fatten.shape[0]): + for j in range(src_fatten.shape[-1]): + result = segment_mean(src_fatten[i, :, j], index_flatten[i, :, j]) + if out_size[dim] >= len(result): + out_flatten[i, : len(result), j] = result + else: + out_flatten[i, :, j] = result[: out_size[dim]] + assign(out_flatten.reshape(out_size), out) + return out + + +def segment_min_coo( + src: paddle.Tensor, + index: paddle.Tensor, + out: Optional[paddle.Tensor] = None, + dim_size: Optional[int] = None, +) -> Tuple[paddle.Tensor, paddle.Tensor]: + r"""Reduces all values from the `src` tensor into `out` at the + indices specified in the `index` tensor along the last dimension of + `index`. The reduction method is min. + + Args: + src (paddle.Tensor): The source tensor. + index (paddle.Tensor): The sorted indices of elements to segment. + The number of dimensions of `index` needs to be less than or + equal to `src`. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + dim_size (int|None, optional): If `out` is not given, automatically create output + with size `dim_size` at dimension `dim`. If `dim_size` is not given, + a minimal sized output tensor according to `index.max() + 1` is returned. + Default is None. + + Returns: + Tuple[paddle.Tensor, paddle.Tensor], the reduced min tensor and arg_min tensor. + """ + src_shape = src.shape + index_shape = index.shape + dim = len(index_shape) - 1 + # broadcast indptr to src + index_shape[:dim] = src_shape[:dim] + if src.numel() == 0: + index = index.reshape(index_shape) + else: + index = index.expand(index_shape) + size = src.shape + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + tmp = index.index_select( + index=to_tensor([index.shape[dim] - 1]), axis=dim + ).squeeze(dim) + tmp = tmp.max() if tmp.numel() > 1 else tmp + size[dim] = int(tmp) + 1 + + if out is None: + if src.numel() == 0: + return ( + zeros(size, dtype=src.dtype), + full(size, src.shape[dim], index.dtype), + ) + return custom_segment_coo_min_max(src, index, None, size, "min") + else: + if src.numel() == 0: + return (out, full(size, src.shape[dim], index.dtype)) + for i in range(len(size)): + if i != dim: + assert size[i] == out.shape[i] + result, arg_result = custom_segment_coo_min_max( + src, index, out, out.shape, "min" + ) + assign(result, out) + return out, arg_result + + +def segment_max_coo( + src: paddle.Tensor, + index: paddle.Tensor, + out: Optional[paddle.Tensor] = None, + dim_size: Optional[int] = None, +) -> Tuple[paddle.Tensor, paddle.Tensor]: + r"""Reduces all values from the `src` tensor into `out` at the + indices specified in the `index` tensor along the last dimension of + `index`. The reduction method is max. + + Args: + src (paddle.Tensor): The source tensor. 
+ index (paddle.Tensor): The sorted indices of elements to segment. + The number of dimensions of `index` needs to be less than or + equal to `src`. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + dim_size (int|None, optional): If `out` is not given, automatically create output + with size `dim_size` at dimension `dim`. If `dim_size` is not given, + a minimal sized output tensor according to `index.max() + 1` is returned. + Default is None. + + Returns: + Tuple[paddle.Tensor, paddle.Tensor], the reduced max tensor and arg_max tensor. + """ + src_shape = src.shape + index_shape = index.shape + dim = len(index_shape) - 1 + # broadcast indptr to src + index_shape[:dim] = src_shape[:dim] + if src.numel() == 0: + index = index.reshape(index_shape) + else: + index = index.expand(index_shape) + size = src.shape + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + tmp = index.index_select( + index=to_tensor([index.shape[dim] - 1]), axis=dim + ).squeeze(dim) + tmp = tmp.max() if tmp.numel() > 1 else tmp + size[dim] = int(tmp) + 1 + + if out is None: + if src.numel() == 0: + return ( + zeros(size, dtype=src.dtype), + full(size, src.shape[dim], index.dtype), + ) + return custom_segment_coo_min_max(src, index, None, size, "max") + else: + if src.numel() == 0: + return (out, full(size, src.shape[dim], index.dtype)) + for i in range(len(size)): + if i != dim: + assert size[i] == out.shape[i] + result, arg_result = custom_segment_coo_min_max( + src, index, out, out.shape, "max" + ) + assign(result, out) + return out, arg_result + + +def segment_coo( + src: paddle.Tensor, + index: paddle.Tensor, + out: Optional[paddle.Tensor] = None, + dim_size: Optional[int] = None, + reduce: str = "sum", +) -> paddle.Tensor: + r"""Reduces all values from the `src` tensor into `out` at the + indices specified in the `index` tensor along the last dimension of + `index`. + + For each value in `src`, its output index is specified by its index + in `src` for dimensions outside of `index.dim() - 1` and by the + corresponding value in `index` for dimension `index.dim() - 1`. + The applied reduction is defined via the `reduce` argument. + + Formally, if `src` and `index` are :math:`n`-dimensional and + :math:`m`-dimensional tensors with + size :math:`(x_1, ..., x_{m-1}, x_m, x_{m+1}, ..., x_n)` and + :math:`(x_1, ..., x_{m-1}, x_m)`, respectively, then `out` must be an + :math:`n`-dimensional tensor with size + :math:`(x_1, ..., x_{m-1}, y, x_{m+1}, ..., x_n)`. + Moreover, the values of `index` must be between :math:`0` and + :math:`y - 1` in ascending order. + The `index` tensor supports broadcasting in case its dimensions do + not match with `src`. + + For one-dimensional tensors with `reduce="sum"`, the operation + computes + + $$ + \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j + $$ + + where :math:`\sum_j` is over :math:`j` such that + :math:`\mathrm{index}_j = i`. + + Notes: + In contrast to :meth:`scatter`, this method expects values in `index` + **to be sorted** along dimension `index.dim() - 1`. + + Args: + src (paddle.Tensor): The source tensor. + index (paddle.Tensor): The sorted indices of elements to segment. + The number of dimensions of `index` needs to be less than or + equal to `src`. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + dim_size (int|None, optional): If `out` is not given, automatically create output + with size `dim_size` at dimension `dim`. 
If `dim_size` is not given, + a minimal sized output tensor according to `index.max() + 1` is returned. + Default is None. + reduce (str, optional): The reduce operation (`"sum"`, `"add"`, `"mean"`, + `"min"` or `"max"`). Default is `"sum"`. + + Returns: + paddle.Tensor, the reduced tensor. + + Examples: + >>> from paddle_scatter import segment_coo + + >>> src = paddle.randn([10, 6, 64]) + >>> index = paddle.to_tensor([0, 0, 1, 1, 1, 2]) + >>> index = index.view(1, -1) # Broadcasting in the first and last dim. + + >>> out = segment_coo(src, index, reduce="sum") + + >>> print(out.shape) + [10, 3, 64] + """ + if reduce == "sum" or reduce == "add": + return segment_sum_coo(src, index, out, dim_size) + elif reduce == "mean": + return segment_mean_coo(src, index, out, dim_size) + elif reduce == "min": + return segment_min_coo(src, index, out, dim_size)[0] + elif reduce == "max": + return segment_max_coo(src, index, out, dim_size)[0] + else: + raise ValueError + + +def gather_coo( + src: paddle.Tensor, index: paddle.Tensor, out: Optional[paddle.Tensor] = None +) -> paddle.Tensor: + r"""Gather values from the `src` tensor into `out` at the + indices specified in the `index` tensor along the last dimension of + `index`. + + Formally, if `src` and `index` are :math:`n`-dimensional and + :math:`m`-dimensional tensors with + size :math:`(x_1, ..., x_{m-1}, x_m, x_{m+1}, ..., x_n)` and + :math:`(x_1, ..., x_{m-1}, y)`, respectively, then `out` must be an + :math:`n`-dimensional tensor with size + :math:`(x_1, ..., x_{m-1}, y, x_{m+1}, ..., x_n)`. + Moreover, the elements of `index` must be between :math:`0` and + :math:`x_m - 1` in ascending order. + The `index` tensor supports broadcasting in case its dimensions do + not match with `src`. + + $$ + \mathrm{out}_{i} = \mathrm{src}_{\mathrm{index}_{i}} + $$ + + where :math:`i` is the index at the last dimension of `index`. + + Args: + src (paddle.Tensor): The source tensor. + index (paddle.Tensor): The indices of elements to gather. + The number of dimensions of `index` needs to be less than or + equal to `src`. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + + Returns: + paddle.Tensor, the gathered tensor. 
+ + Examples: + >>> from paddle_scatter import gather_coo + + >>> src = paddle.to_tensor([1, 2, 3, 4]) + >>> index = paddle.to_tensor([0, 0, 1, 1, 1, 3]) + + >>> out = gather_coo(src, index) + + >>> print(out) + Tensor(shape=[6], dtype=int64, place=Place(cpu), stop_gradient=True, + [1, 1, 2, 2, 2, 4]) + """ + index_shape = index.shape + dim = len(index_shape) - 1 + src_shape = src.shape + + # broadcast index's dimension to the same as src's + for _ in range(len(src_shape) - len(index_shape)): + index = index.unsqueeze(-1) + new_index_shape = src_shape.copy() + new_index_shape[dim] = index_shape[dim] + if src.numel() == 0: + index = index.reshape(new_index_shape) + else: + index = index.expand(new_index_shape) + + if out is None: + out_size = src_shape + if index.numel() == 0: + out_size[dim] = 0 + else: + out_size[dim] = index.shape[dim] + out = zeros(out_size, dtype=src.dtype) + else: + out_size = out.shape + if src.numel() == 0: + return out + + result = take_along_axis(src, index, dim, broadcast=False) + if out_size[dim] > result.shape[dim]: + padding = [0] * 2 * len(out_size) + padding[2 * dim + 1] = out_size[dim] - result.shape[dim] + result = pad(result, padding, mode="constant", value=0) + elif out_size[dim] < result.shape[dim]: + result = slice(result, [dim], 0, out_size[dim]) + return assign(result, out) diff --git a/jointContribution/paddle_scatter/segment_csr.py b/jointContribution/paddle_scatter/segment_csr.py new file mode 100644 index 000000000..6ecfbeb92 --- /dev/null +++ b/jointContribution/paddle_scatter/segment_csr.py @@ -0,0 +1,454 @@ +from typing import Optional +from typing import Tuple + +import paddle +from paddle import arange +from paddle import assign +from paddle import full +from paddle import repeat_interleave +from paddle import zeros +from paddle.geometric import segment_mean +from paddle.geometric import segment_sum +from paddle_scatter_min_max_ops import custom_segment_csr_min_max + +from .utils import transform_2d +from .utils import transform_3d + + +def segment_sum_csr( + src: paddle.Tensor, indptr: paddle.Tensor, out: Optional[paddle.Tensor] = None +) -> paddle.Tensor: + r""" + Reduces all values from the `src` tensor into `out` within the + ranges specified in the `indptr` tensor along the last dimension of + `indptr`. + For each value in `src`, its output index is specified by its index + in `src` for dimensions outside of `indptr.dim() - 1` and by the + corresponding range index in `indptr` for dimension + `indptr.dim() - 1`. + The reduction method is sum. + + Args: + src (paddle.Tensor): The source tensor. + indptr (paddle.Tensor): The index pointers between elements to segment. + The number of dimensions of `index` needs to be less than or + equal to `src`. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + + Returns: + paddle.Tensor, the reduced tensor by sum reduction method. + """ + indptr_shape = indptr.shape + src_shape = src.shape + dim = len(indptr_shape) - 1 + # broadcast indptr to src + indptr_shape[:dim] = src_shape[:dim] + if src.numel() == 0: + indptr = indptr.reshape(indptr_shape) + else: + indptr = indptr.expand(indptr_shape) + + num_seg = indptr_shape[dim] - 1 + if out is None: + out_size = src_shape + if indptr.numel() == 0: + out_size[dim] = 0 + else: + out_size[dim] = num_seg + if src.numel() == 0: + return zeros(out_size, dtype=src.dtype) + else: + assert ( + out.shape[dim] == num_seg + ), "The (size of indptr at last dimension) must be\ + equal to the (size of out at the same dimension) + 1." 
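+        # The reduction below works on flattened 3-D views: `indptr.diff()`
+        # gives each segment's length, `repeat_interleave` expands that into a
+        # per-element segment id, and paddle.geometric.segment_sum reduces each
+        # (batch, column) slice before the result is written back into `out`.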
+ if src.numel() == 0: + return out + out_size = out.shape + tmp = zeros(out_size, dtype=src.dtype) + + repeats = indptr.diff(n=1, axis=dim) + assert ( + repeats.sum(axis=dim) == src.shape[dim] + ).all(), "The length of specified index by indptr shoud be\ + equal to the size of src at last dimension of indptr." + src_flatten = transform_3d(src, dim) + out_flatten = transform_3d(tmp, dim) + repeats_flatten = transform_2d(repeats, dim) + src_dim_indices = arange(num_seg) + for i in range(src_flatten.shape[0]): + for j in range(src_flatten.shape[-1]): + belongs_to = repeat_interleave(src_dim_indices, repeats_flatten[i], 0) + result = segment_sum(src_flatten[i, :, j], belongs_to) + if out_size[dim] >= len(result): + out_flatten[i, : len(result), j] = result + else: + out_flatten[i, :, j] = result[: out_size[dim]] + if out is None: + return out_flatten.reshape(out_size) + else: + assign(out_flatten.reshape(out_size), out) + return out + + +def segment_add_csr( + src: paddle.Tensor, indptr: paddle.Tensor, out: Optional[paddle.Tensor] = None +) -> paddle.Tensor: + r""" + Reduces all values from the `src` tensor into `out` within the + ranges specified in the `indptr` tensor along the last dimension of + `indptr`. + For each value in `src`, its output index is specified by its index + in `src` for dimensions outside of `indptr.dim() - 1` and by the + corresponding range index in `indptr` for dimension + `indptr.dim() - 1`. + The reduction method is sum. + + Args: + src (paddle.Tensor): The source tensor. + indptr (paddle.Tensor): The index pointers between elements to segment. + The number of dimensions of `index` needs to be less than or + equal to `src`. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + + Returns: + paddle.Tensor, the reduced tensor by sum reduction method. + """ + return segment_sum_csr(src, indptr, out) + + +def segment_mean_csr( + src: paddle.Tensor, indptr: paddle.Tensor, out: Optional[paddle.Tensor] = None +) -> paddle.Tensor: + r""" + Reduces all values from the `src` tensor into `out` within the + ranges specified in the `indptr` tensor along the last dimension of + `indptr`. + For each value in `src`, its output index is specified by its index + in `src` for dimensions outside of `indptr.dim() - 1` and by the + corresponding range index in `indptr` for dimension + `indptr.dim() - 1`. + The reduction method is mean. + + Args: + src (paddle.Tensor): The source tensor. + indptr (paddle.Tensor): The index pointers between elements to segment. + The number of dimensions of `index` needs to be less than or + equal to `src`. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + + Returns: + paddle.Tensor, the reduced tensor by mean reduction method. + """ + indptr_shape = indptr.shape + src_shape = src.shape + dim = len(indptr_shape) - 1 + # broadcast indptr to src + indptr_shape[:dim] = src_shape[:dim] + if src.numel() == 0: + indptr = indptr.reshape(indptr_shape) + else: + indptr = indptr.expand(indptr_shape) + + num_seg = indptr_shape[dim] - 1 + if out is None: + out_size = src_shape + if indptr.numel() == 0: + out_size[dim] = 0 + else: + out_size[dim] = num_seg + if src.numel() == 0: + return zeros(out_size, dtype=src.dtype) + else: + assert ( + out.shape[dim] == num_seg + ), "The (size of indptr at last dimension) must be\ + equal to the (size of out at the same dimension) + 1." 
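+        # Same flatten-and-segment scheme as segment_sum_csr, except each
+        # (batch, column) slice is reduced with paddle.geometric.segment_mean.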
+ if src.numel() == 0: + return out + out_size = out.shape + tmp = zeros(out_size, dtype=src.dtype) + + repeats = indptr.diff(n=1, axis=dim) + assert ( + repeats.sum(axis=dim) == src.shape[dim] + ).all(), "The length of specified index by indptr shoud be\ + equal to the size of src at last dimension of indptr." + src_flatten = transform_3d(src, dim) + out_flatten = transform_3d(tmp, dim) + repeats_flatten = transform_2d(repeats, dim) + src_dim_indices = arange(num_seg) + for i in range(src_flatten.shape[0]): + for j in range(src_flatten.shape[-1]): + belongs_to = repeat_interleave(src_dim_indices, repeats_flatten[i], 0) + result = segment_mean(src_flatten[i, :, j], belongs_to) + if out_size[dim] >= len(result): + out_flatten[i, : len(result), j] = result + else: + out_flatten[i, :, j] = result[: out_size[dim]] + if out is None: + return out_flatten.reshape(out_size) + else: + assign(out_flatten.reshape(out_size), out) + return out + + +def segment_min_csr( + src: paddle.Tensor, indptr: paddle.Tensor, out: Optional[paddle.Tensor] = None +) -> Tuple[paddle.Tensor, paddle.Tensor]: + r""" + Reduces all values from the `src` tensor into `out` within the + ranges specified in the `indptr` tensor along the last dimension of + `indptr`. + For each value in `src`, its output index is specified by its index + in `src` for dimensions outside of `indptr.dim() - 1` and by the + corresponding range index in `indptr` for dimension + `indptr.dim() - 1`. + The reduction method is min. + + Args: + src (paddle.Tensor): The source tensor. + indptr (paddle.Tensor): The index pointers between elements to segment. + The number of dimensions of `index` needs to be less than or + equal to `src`. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + + Returns: + Tuple[paddle.Tensor, paddle.Tensor], the reduced min tensor and arg_min tensor. + """ + indptr_shape = indptr.shape + src_shape = src.shape + dim = len(indptr_shape) - 1 + # broadcast indptr to src + indptr_shape[:dim] = src_shape[:dim] + if src.numel() == 0: + indptr = indptr.reshape(indptr_shape) + else: + indptr = indptr.expand(indptr_shape) + + if out is None: + size = src.shape + size[dim] = max(indptr_shape[dim] - 1, 0) + if src.numel() == 0: + return ( + zeros(size, dtype=src.dtype), + full(size, src.shape[dim], indptr.dtype), + ) + return custom_segment_csr_min_max(src, indptr, size, "min") + else: + if src.numel() == 0: + return (out, full(size, src.shape[dim], indptr.dtype)) + result, arg_result = custom_segment_csr_min_max(src, indptr, out.shape, "min") + assign(result, out) + return out, arg_result + + +def segment_max_csr( + src: paddle.Tensor, indptr: paddle.Tensor, out: Optional[paddle.Tensor] = None +) -> Tuple[paddle.Tensor, paddle.Tensor]: + r""" + Reduces all values from the `src` tensor into `out` within the + ranges specified in the `indptr` tensor along the last dimension of + `indptr`. + For each value in `src`, its output index is specified by its index + in `src` for dimensions outside of `indptr.dim() - 1` and by the + corresponding range index in `indptr` for dimension + `indptr.dim() - 1`. + The reduction method is max. + + Args: + src (paddle.Tensor): The source tensor. + indptr (paddle.Tensor): The index pointers between elements to segment. + The number of dimensions of `index` needs to be less than or + equal to `src`. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + + Returns: + Tuple[paddle.Tensor, paddle.Tensor], the reduced max tensor and arg_max tensor. 
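+    Examples:
+        A minimal usage sketch (shape check only):
+
+        >>> from paddle_scatter import segment_max_csr
+
+        >>> src = paddle.to_tensor([1.0, 3.0, 2.0, 5.0, 4.0])
+        >>> indptr = paddle.to_tensor([0, 2, 5])
+
+        >>> out, argmax = segment_max_csr(src, indptr)
+        >>> print(out.shape, argmax.shape)
+        [2] [2]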
+ """ + indptr_shape = indptr.shape + src_shape = src.shape + dim = len(indptr_shape) - 1 + # broadcast indptr to src + indptr_shape[:dim] = src_shape[:dim] + if src.numel() == 0: + indptr = indptr.reshape(indptr_shape) + else: + indptr = indptr.expand(indptr_shape) + + if out is None: + size = src.shape + size[dim] = max(indptr_shape[dim] - 1, 0) + if src.numel() == 0: + return ( + zeros(size, dtype=src.dtype), + full(size, src.shape[dim], indptr.dtype), + ) + return custom_segment_csr_min_max(src, indptr, size, "max") + else: + if src.numel() == 0: + return (out, full(size, src.shape[dim], indptr.dtype)) + result, arg_result = custom_segment_csr_min_max(src, indptr, out.shape, "max") + assign(result, out) + return out, arg_result + + +def segment_csr( + src: paddle.Tensor, + indptr: paddle.Tensor, + out: Optional[paddle.Tensor] = None, + reduce: str = "sum", +) -> paddle.Tensor: + r""" + Reduces all values from the `src` tensor into `out` within the + ranges specified in the `indptr` tensor along the last dimension of + `indptr`. + For each value in `src`, its output index is specified by its index + in `src` for dimensions outside of `indptr.dim() - 1` and by the + corresponding range index in `indptr` for dimension + `indptr.dim() - 1`. + The applied reduction is defined via the `reduce` argument. + + Formally, if `src` and `indptr` are :math:`n`-dimensional and + :math:`m`-dimensional tensors with + size :math:`(x_1, ..., x_{m-1}, x_m, x_{m+1}, ..., x_n)` and + :math:`(x_1, ..., x_{m-1}, y)`, respectively, then `out` must be an + :math:`n`-dimensional tensor with size + :math:`(x_1, ..., x_{m-1}, y - 1, x_{m+1}, ..., x_n)`. + Moreover, the values of `indptr` must be between :math:`0` and + :math:`x_m` in ascending order. + The `indptr` tensor supports broadcasting in case its dimensions do + not match with `src`. + + For one-dimensional tensors with `reduce="sum"`, the operation + computes + + $$ + \mathrm{out}_i = + \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+1]-1}~\mathrm{src}_j. + $$ + + Args: + src (paddle.Tensor): The source tensor. + indptr (paddle.Tensor): The index pointers between elements to segment. + The number of dimensions of `index` needs to be less than or + equal to `src`. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + reduce (str, optional): The reduce operation (`"sum"`, `"add"`, `"mean"`, + `"min"` or `"max"`). Default is `"sum"`. + + Returns: + paddle.Tensor, the reduced tensor. + + Examples: + >>> from paddle_scatter import segment_csr + + >>> src = paddle.randn([10, 6, 64]) + >>> indptr = paddle.tensor([0, 2, 5, 6]) + >>> indptr = indptr.view(1, -1) # Broadcasting in the first and last dim. + + >>> out = segment_csr(src, indptr, reduce="sum") + + >>> print(out.shape) + [10, 3, 64] + """ + if reduce == "sum" or reduce == "add": + return segment_sum_csr(src, indptr, out) + elif reduce == "mean": + return segment_mean_csr(src, indptr, out) + elif reduce == "min": + return segment_min_csr(src, indptr, out)[0] + elif reduce == "max": + return segment_max_csr(src, indptr, out)[0] + else: + raise ValueError + + +def gather_csr( + src: paddle.Tensor, indptr: paddle.Tensor, out: Optional[paddle.Tensor] = None +) -> paddle.Tensor: + r"""Gather values from the `src` tensor into `out` at the + indices specified within the ranges specified in the `indptr` + tensor along the last dimension of `indptr`. 
+ + Formally, if `src` and `indptr` are :math:`n`-dimensional and + :math:`m`-dimensional tensors with + size :math:`(x_1, ..., x_{m-1}, x_m, x_{m+1}, ..., x_n)` and + :math:`(x_1, ..., x_{m-1}, y)` (y = x_m + 1), respectively, + then `out` must be an :math:`n`-dimensional tensor with size + :math:`(x_1, ..., x_{m-1}, k, x_{m+1}, ..., x_n)`, where :math:`k` + is the number of segments specified by `indptr`. + Moreover, the values of `indptr` must be between :math:`0` and + :math:`x_m` in ascending order. + The `indptr` tensor supports broadcasting in case its dimensions do + not match with `src`. + + $$ + \mathrm{out}_i = \mathrm{src}[indptr[k]], + k = indptr[(indptr - i <= 0)][-1] + $$ + + where :math:`i` is the index at the last dimension of `index`. + + Args: + src (paddle.Tensor): The source tensor. + indptr (paddle.Tensor): The index pointers between elements to segment. + The number of dimensions of `index` needs to be less than or + equal to `src`. + out (paddle.Tensor|None, optional): The destination tensor. Default is None. + + Returns: + paddle.Tensor, the gathered tensor. + + Examples: + >>> from paddle_scatter import gather_csr + + >>> src = paddle.to_tensor([1, 2, 3, 4]) + >>> indptr = paddle.to_tensor([0, 2, 5, 5, 6]) + + >>> out = gather_csr(src, indptr) + + >>> print(out) + Tensor(shape=[6], dtype=int64, place=Place(cpu), stop_gradient=True, + [1, 1, 2, 2, 2, 4]) + """ + indptr_shape = indptr.shape + src_shape = src.shape + dim = len(indptr_shape) - 1 + # broadcast indptr to src + indptr_shape[:dim] = src_shape[:dim] + if src.numel() == 0: + indptr = indptr.reshape(indptr_shape) + else: + indptr = indptr.expand(indptr_shape) + assert ( + src_shape[dim] == indptr_shape[dim] - 1 + ), "The (size of indptr at last dimension) must be equal to\ + the (size of src at the same dimension) + 1." 
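+    # Gathering is the inverse expansion of segment_csr: `indptr.diff()` gives
+    # how many times each source row is repeated, and `repeat_interleave`
+    # expands every flattened (batch, column) slice to the requested output
+    # length along `dim`.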
+ if out is None: + out_size = src_shape + if indptr.numel() == 0: + out_size[dim] = 0 + else: + # refer to the original design in source cpp code + out_size[dim] = indptr.flatten()[-1] + out = zeros(out_size, dtype=src.dtype) + else: + out_size = out.shape + if src.numel() == 0: + return out + + repeats = indptr.diff(n=1, axis=dim) + src_flatten = transform_3d(src, dim) + out_flatten = transform_3d(out, dim) + repeats_flatten = transform_2d(repeats, dim) + for i in range(src_flatten.shape[0]): + for j in range(src_flatten.shape[-1]): + result = repeat_interleave(src_flatten[i, :, j], repeats_flatten[i], 0) + repeat_sum = repeats_flatten[i].sum() + if out_size[dim] >= repeat_sum: + out_flatten[i, :repeat_sum, j] = result[:repeat_sum] + else: + out_flatten[i, :, j] = result[: out_size[dim]] + assign(out_flatten.reshape(out_size), out) + return out diff --git a/jointContribution/paddle_scatter/setup.py b/jointContribution/paddle_scatter/setup.py new file mode 100644 index 000000000..d78a28133 --- /dev/null +++ b/jointContribution/paddle_scatter/setup.py @@ -0,0 +1,41 @@ +import os + +import paddle +from paddle.utils.cpp_extension import CppExtension +from paddle.utils.cpp_extension import CUDAExtension +from paddle.utils.cpp_extension import setup + + +def get_sources(): + csrc_dir_path = os.path.join(os.path.dirname(__file__), "csrc") + cpp_files = [] + for item in os.listdir(csrc_dir_path): + if paddle.core.is_compiled_with_cuda(): + if item.endswith(".cc") or item.endswith(".cu"): + cpp_files.append(os.path.join(csrc_dir_path, item)) + else: + if item.endswith(".cc"): + cpp_files.append(os.path.join(csrc_dir_path, item)) + return csrc_dir_path, cpp_files + + +def get_extensions(): + src = get_sources() + Extension = CUDAExtension if paddle.core.is_compiled_with_cuda() else CppExtension + ext_modules = [ + Extension( + sources=src[1], + include_dirs=src[0], + ) + ] + return ext_modules + + +setup( + name="paddle_scatter_min_max_ops", + version="1.0", + author="NKNaN", + url="https://github.com/PaddlePaddle/PaddleScience/jointContribution/paddle_scatter", + description="Paddle extension of scatter and segment operators with min and max reduction methods", + ext_modules=get_extensions(), +) diff --git a/jointContribution/paddle_scatter/testing.py b/jointContribution/paddle_scatter/testing.py new file mode 100644 index 000000000..dfe23294c --- /dev/null +++ b/jointContribution/paddle_scatter/testing.py @@ -0,0 +1,23 @@ +from typing import Any + +import paddle + +reductions = ["sum", "add", "mean", "min", "max"] + +dtypes = [ + # paddle.float16, paddle.bfloat16, + paddle.float32, + paddle.float64, + paddle.int32, + paddle.int64, +] +ind_dtypes = [paddle.int32, paddle.int64] +grad_dtypes = [paddle.float32, paddle.float64] + +places = ["cpu"] +if paddle.core.is_compiled_with_cuda(): + places.append("gpu") + + +def tensor(x: Any, dtype: paddle.dtype): + return None if x is None else paddle.to_tensor(x).astype(dtype) diff --git a/jointContribution/paddle_scatter/tests/composite/test_logsumexp.py b/jointContribution/paddle_scatter/tests/composite/test_logsumexp.py new file mode 100644 index 000000000..666924018 --- /dev/null +++ b/jointContribution/paddle_scatter/tests/composite/test_logsumexp.py @@ -0,0 +1,39 @@ +import paddle +from paddle_scatter import scatter_logsumexp + + +def test_logsumexp(): + inputs = paddle.to_tensor( + [ + 0.5, + 0.5, + 0.0, + -2.1, + 3.2, + 7.0, + -1.0, + -100.0, + ] + ) + inputs.stop_gradient = False + index = paddle.to_tensor([0, 0, 1, 1, 1, 2, 4, 4]) + splits = [2, 
3, 1, 0, 2] + + outputs = scatter_logsumexp(inputs, index) + + for src, out in zip(inputs.split(splits), outputs.unbind()): + if src.numel() > 0: + assert out.numpy() == paddle.logsumexp(src, axis=0).numpy() + else: + assert out.item() == 0.0 + + outputs.backward(paddle.randn(outputs.shape, outputs.dtype)) + + +def test_logsumexp_out(): + src = paddle.to_tensor([-1.0, -50.0]) + index = paddle.to_tensor([0, 0]) + out = paddle.to_tensor([-10.0, -10.0]) + + scatter_logsumexp(src=src, index=index, out=out) + assert out.allclose(paddle.to_tensor([-0.9999, -10.0]), atol=1e-4) diff --git a/jointContribution/paddle_scatter/tests/composite/test_softmax.py b/jointContribution/paddle_scatter/tests/composite/test_softmax.py new file mode 100644 index 000000000..20c2c6bb8 --- /dev/null +++ b/jointContribution/paddle_scatter/tests/composite/test_softmax.py @@ -0,0 +1,51 @@ +import paddle +from paddle_scatter import scatter_log_softmax +from paddle_scatter import scatter_softmax + + +def test_softmax(): + src = paddle.to_tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float("-inf")]) + src.stop_gradient = False + index = paddle.to_tensor([0, 1, 0, 1, 1, 2, 4, 4]) + + out = scatter_softmax(src, index) + + out0 = paddle.nn.functional.softmax(paddle.to_tensor([0.2, 0.2]), axis=-1) + out1 = paddle.nn.functional.softmax(paddle.to_tensor([0, -2.1, 3.2]), axis=-1) + out2 = paddle.nn.functional.softmax( + paddle.to_tensor([7], dtype=paddle.float32), axis=-1 + ) + out4 = paddle.nn.functional.softmax(paddle.to_tensor([-1, float("-inf")]), axis=-1) + + expected = paddle.stack( + [out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]], axis=0 + ) + + assert paddle.allclose(out, expected) + + out.backward(paddle.randn(out.shape, out.dtype)) + + +def test_log_softmax(): + src = paddle.to_tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float("-inf")]) + src.stop_gradient = False + index = paddle.to_tensor([0, 1, 0, 1, 1, 2, 4, 4]) + + out = scatter_log_softmax(src, index) + + out0 = paddle.nn.functional.log_softmax(paddle.to_tensor([0.2, 0.2]), axis=-1) + out1 = paddle.nn.functional.log_softmax(paddle.to_tensor([0, -2.1, 3.2]), axis=-1) + out2 = paddle.nn.functional.log_softmax( + paddle.to_tensor([7], dtype=paddle.float32), axis=-1 + ) + out4 = paddle.nn.functional.log_softmax( + paddle.to_tensor([-1, float("-inf")]), axis=-1 + ) + + expected = paddle.stack( + [out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]], axis=0 + ) + + assert paddle.allclose(out, expected) + + out.backward(paddle.randn(out.shape, out.dtype)) diff --git a/jointContribution/paddle_scatter/tests/composite/test_std.py b/jointContribution/paddle_scatter/tests/composite/test_std.py new file mode 100644 index 000000000..00ffb364b --- /dev/null +++ b/jointContribution/paddle_scatter/tests/composite/test_std.py @@ -0,0 +1,27 @@ +import paddle +from paddle_scatter import scatter_std + + +def test_std(): + src = paddle.to_tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], dtype=paddle.float32) + src.stop_gradient = False + index = paddle.to_tensor([[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], dtype=paddle.int64) + + out = scatter_std(src, index, dim=-1, unbiased=True) + std = src.std(axis=-1, unbiased=True)[0] + expected = paddle.to_tensor([[std, 0], [0, std]]) + assert paddle.allclose(out, expected) + + out.backward(paddle.randn(out.shape, out.dtype)) + + +def test_std_out(): + src = paddle.to_tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], dtype=paddle.float32) + index = paddle.to_tensor([[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], dtype=paddle.int64) + out = 
paddle.to_tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + + scatter_std(src, index, dim=-1, out=out, unbiased=True) + std = src.std(axis=-1, unbiased=True)[0] + expected = paddle.to_tensor([[std, 0, 0], [0, std, 0]]) + + assert paddle.allclose(out, expected) diff --git a/jointContribution/paddle_scatter/tests/test_broadcasting.py b/jointContribution/paddle_scatter/tests/test_broadcasting.py new file mode 100644 index 000000000..8bf7065b1 --- /dev/null +++ b/jointContribution/paddle_scatter/tests/test_broadcasting.py @@ -0,0 +1,29 @@ +from itertools import product + +import paddle +import pytest +from paddle_scatter import scatter +from paddle_scatter.testing import places +from paddle_scatter.testing import reductions + + +@pytest.mark.parametrize("reduce,place", product(reductions, places)) +def test_broadcasting(reduce, place): + paddle.set_device(place) + + B, C, H, W = (4, 3, 8, 8) + + src = paddle.randn((B, C, H, W)) + index = paddle.randint(0, H, (H,)).astype(paddle.int64) + out = scatter(src, index, dim=2, dim_size=H, reduce=reduce) + assert out.shape == [B, C, H, W] + + src = paddle.randn((B, C, H, W)) + index = paddle.randint(0, H, (B, 1, H, W)).astype(paddle.int64) + out = scatter(src, index, dim=2, dim_size=H, reduce=reduce) + assert out.shape == [B, C, H, W] + + src = paddle.randn((B, C, H, W)) + index = paddle.randint(0, H, (H,)).astype(paddle.int64) + out = scatter(src, index, dim=2, dim_size=H, reduce=reduce) + assert out.shape == [B, C, H, W] diff --git a/jointContribution/paddle_scatter/tests/test_gather.py b/jointContribution/paddle_scatter/tests/test_gather.py new file mode 100644 index 000000000..fbbb4df42 --- /dev/null +++ b/jointContribution/paddle_scatter/tests/test_gather.py @@ -0,0 +1,145 @@ +from itertools import product + +import paddle +import pytest +from paddle_scatter import gather_coo +from paddle_scatter import gather_csr +from paddle_scatter.testing import dtypes +from paddle_scatter.testing import ind_dtypes +from paddle_scatter.testing import places +from paddle_scatter.testing import tensor + +tests = [ + { + "src": [1, 2, 3, 4], + "index": [0, 0, 1, 1, 1, 3], + "indptr": [0, 2, 5, 5, 6], + "expected": [1, 1, 2, 2, 2, 4], + "expected_grad": [2, 3, 0, 1], + }, + { + "src": [[1, 2], [3, 4], [5, 6], [7, 8]], + "index": [0, 0, 1, 1, 1, 3], + "indptr": [0, 2, 5, 5, 6], + "expected": [[1, 2], [1, 2], [3, 4], [3, 4], [3, 4], [7, 8]], + "expected_grad": [[2, 2], [3, 3], [0, 0], [1, 1]], + }, + { + "src": [[1, 3, 5, 7], [2, 4, 6, 8]], + "index": [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]], + "indptr": [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]], + "expected": [[1, 1, 3, 3, 3, 7], [2, 2, 2, 4, 4, 6]], + "expected_grad": [[2, 3, 0, 1], [3, 2, 1, 0]], + }, + { + "src": [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]], + "index": [[0, 0, 1], [0, 2, 2]], + "indptr": [[0, 2, 3, 3], [0, 1, 1, 3]], + "expected": [[[1, 2], [1, 2], [3, 4]], [[7, 9], [12, 13], [12, 13]]], + "expected_grad": [[[2, 2], [1, 1], [0, 0]], [[1, 1], [0, 0], [2, 2]]], + }, + { + "src": [[1], [2]], + "index": [[0, 0], [0, 0]], + "indptr": [[0, 2], [0, 2]], + "expected": [[1, 1], [2, 2]], + "expected_grad": [[2], [2]], + }, + { + "src": [[[1, 1]], [[2, 2]]], + "index": [[0, 0], [0, 0]], + "indptr": [[0, 2], [0, 2]], + "expected": [[[1, 1], [1, 1]], [[2, 2], [2, 2]]], + "expected_grad": [[[2, 2]], [[2, 2]]], + }, +] + + +@pytest.mark.parametrize( + "test,dtype,ind_dtype,place", product(tests, dtypes, ind_dtypes, places) +) +def test_forward(test, dtype, ind_dtype, place): + paddle.set_device(place) + src = 
tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + indptr = tensor(test["indptr"], ind_dtype) + expected = tensor(test["expected"], dtype) + + out = gather_csr(src, indptr) + assert paddle.all(out == expected) + + out = gather_coo(src, index) + assert paddle.all(out == expected) + + +@pytest.mark.parametrize("test,place", product(tests, places)) +def test_backward(test, place): + paddle.set_device(place) + index = tensor(test["index"], paddle.int64) + indptr = tensor(test["indptr"], paddle.int64) + exp_grad = tensor(test["expected_grad"], paddle.float64) + + src = tensor(test["src"], paddle.float64) + src.stop_gradient = False + out = gather_csr(src, indptr) + out.backward() + assert paddle.all(src.grad == exp_grad) + + src = tensor(test["src"], paddle.float64) + src.stop_gradient = False + out = gather_coo(src, index) + out.backward() + assert paddle.all(src.grad == exp_grad) + + +@pytest.mark.parametrize( + "test,dtype,ind_dtype,place", product(tests, dtypes, ind_dtypes, places) +) +def test_out(test, dtype, ind_dtype, place): + paddle.set_device(place) + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + indptr = tensor(test["indptr"], ind_dtype) + expected = tensor(test["expected"], dtype) + + size = src.shape + size[index.dim() - 1] = index.shape[-1] + out = paddle.full(size, -2).astype(dtype) + + gather_csr(src, indptr, out) + assert paddle.all(out == expected) + + out.fill_(-2) + + gather_coo(src, index, out) + assert paddle.all(out == expected) + + +@pytest.mark.parametrize( + "test,dtype,ind_dtype,place", product(tests, dtypes, ind_dtypes, places) +) +def test_non_contiguous(test, dtype, ind_dtype, place): + paddle.set_device(place) + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + indptr = tensor(test["indptr"], ind_dtype) + expected = tensor(test["expected"], dtype) + + if src.dim() > 1: + shape = list(range(src.dim())) + shape[0], shape[1] = shape[1], shape[0] + src = src.transpose(shape).contiguous().transpose(shape) + if index.dim() > 1: + shape = list(range(index.dim())) + shape[0], shape[1] = shape[1], shape[0] + index = index.transpose(shape).contiguous().transpose(shape) + if indptr.dim() > 1: + shape = list(range(indptr.dim())) + shape[0], shape[1] = shape[1], shape[0] + indptr = indptr.transpose(shape).contiguous().transpose(shape) + + out = gather_csr(src, indptr) + assert paddle.all(out == expected) + + out = gather_coo(src, index) + assert paddle.all(out == expected) diff --git a/jointContribution/paddle_scatter/tests/test_multi_gpu.py b/jointContribution/paddle_scatter/tests/test_multi_gpu.py new file mode 100644 index 000000000..e8dd74b6b --- /dev/null +++ b/jointContribution/paddle_scatter/tests/test_multi_gpu.py @@ -0,0 +1,43 @@ +from itertools import product + +import paddle +import paddle_scatter +import pytest +from paddle_scatter.testing import dtypes +from paddle_scatter.testing import reductions +from paddle_scatter.testing import tensor + +tests = [ + { + "src": [1, 2, 3, 4, 5, 6], + "index": [0, 0, 1, 1, 1, 3], + "indptr": [0, 2, 5, 5, 6], + "dim": 0, + "sum": [3, 12, 0, 6], + "add": [3, 12, 0, 6], + "mean": [1.5, 4, 0, 6], + "min": [1, 3, 0, 6], + "max": [2, 5, 0, 6], + }, +] + + +@pytest.mark.skipif(not paddle.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(paddle.cuda.device_count() < 2, reason="No multiple GPUS") +@pytest.mark.parametrize("test,reduce,dtype", product(tests, reductions, dtypes)) +def test_forward(test, reduce, dtype): + 
paddle.set_device("gpu:1") + src = tensor(test["src"], dtype) + index = tensor(test["index"], paddle.int64) + indptr = tensor(test["indptr"], paddle.int64) + dim = test["dim"] + expected = tensor(test[reduce], dtype) + + out = paddle_scatter.scatter(src, index, dim, reduce=reduce) + assert paddle.all(out == expected) + + out = paddle_scatter.segment_coo(src, index, reduce=reduce) + assert paddle.all(out == expected) + + out = paddle_scatter.segment_csr(src, indptr, reduce=reduce) + assert paddle.all(out == expected) diff --git a/jointContribution/paddle_scatter/tests/test_scatter.py b/jointContribution/paddle_scatter/tests/test_scatter.py new file mode 100644 index 000000000..51c519222 --- /dev/null +++ b/jointContribution/paddle_scatter/tests/test_scatter.py @@ -0,0 +1,296 @@ +from itertools import product + +import numpy as np +import paddle +import paddle_scatter +import pytest +from paddle_scatter.testing import dtypes +from paddle_scatter.testing import ind_dtypes +from paddle_scatter.testing import places +from paddle_scatter.testing import reductions +from paddle_scatter.testing import tensor + +reductions = reductions + ["mul"] + +tests = [ + { + "src": [1, 3, 2, 4, 5, 6], + "index": [0, 1, 0, 1, 1, 3], + "dim": -1, + "sum": [3, 12, 0, 6], + "add": [3, 12, 0, 6], + "mul": [2, 60, 1, 6], + "mean": [1.5, 4, 0, 6], + "min": [1, 3, 0, 6], + "arg_min": [0, 1, 6, 5], + "max": [2, 5, 0, 6], + "arg_max": [2, 4, 6, 5], + "sum_grad": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + "add_grad": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + "mean_grad": [0.5000, 0.33333333, 0.5000, 0.33333333, 0.33333333, 1.0000], + "min_grad": [1.0, 1.0, 0.0, 0.0, 0.0, 1.0], + "max_grad": [0.0, 0.0, 1.0, 0.0, 1.0, 1.0], + "mul_grad": [2.0, 20.0, 1.0, 15.0, 12.0, 1.0], + }, + { + "src": [[1, 2], [5, 6], [3, 4], [7, 8], [9, 10], [11, 12]], + "index": [0, 1, 0, 1, 1, 3], + "dim": 0, + "sum": [[4, 6], [21, 24], [0, 0], [11, 12]], + "add": [[4, 6], [21, 24], [0, 0], [11, 12]], + "mul": [[1 * 3, 2 * 4], [5 * 7 * 9, 6 * 8 * 10], [1, 1], [11, 12]], + "mean": [[2, 3], [7, 8], [0, 0], [11, 12]], + "min": [[1, 2], [5, 6], [0, 0], [11, 12]], + "arg_min": [[0, 0], [1, 1], [6, 6], [5, 5]], + "max": [[3, 4], [9, 10], [0, 0], [11, 12]], + "arg_max": [[2, 2], [4, 4], [6, 6], [5, 5]], + "sum_grad": [ + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + ], + "add_grad": [ + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + ], + "mean_grad": [ + [0.5000, 0.5000], + [0.33333333, 0.33333333], + [0.5000, 0.5000], + [0.33333333, 0.33333333], + [0.33333333, 0.33333333], + [1.0000, 1.0000], + ], + "min_grad": [ + [1.0, 1.0], + [1.0, 1.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [1.0, 1.0], + ], + "max_grad": [ + [0.0, 0.0], + [0.0, 0.0], + [1.0, 1.0], + [0.0, 0.0], + [1.0, 1.0], + [1.0, 1.0], + ], + "mul_grad": [ + [3.0, 4.0], + [63.0, 80.0], + [1.0, 2.0], + [45.0, 60.0], + [35.0, 48.0], + [1.0, 1.0], + ], + }, + { + "src": [[1, 5, 3, 7, 9, 11], [2, 4, 8, 6, 10, 12]], + "index": [[0, 1, 0, 1, 1, 3], [0, 0, 1, 0, 1, 2]], + "dim": 1, + "sum": [[4, 21, 0, 11], [12, 18, 12, 0]], + "add": [[4, 21, 0, 11], [12, 18, 12, 0]], + "mul": [[1 * 3, 5 * 7 * 9, 1, 11], [2 * 4 * 6, 8 * 10, 12, 1]], + "mean": [[2, 7, 0, 11], [4, 9, 12, 0]], + "min": [[1, 5, 0, 11], [2, 8, 12, 0]], + "arg_min": [[0, 1, 6, 5], [0, 2, 5, 6]], + "max": [[3, 9, 0, 11], [6, 10, 12, 0]], + "arg_max": [[2, 4, 6, 5], [3, 4, 5, 6]], + "sum_grad": [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], + "add_grad": [[1.0, 
1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], + "mean_grad": [ + [0.5000, 0.33333333, 0.5000, 0.33333333, 0.33333333, 1.0000], + [0.33333333, 0.33333333, 0.5000, 0.33333333, 0.5000, 1.0000], + ], + "min_grad": [[1.0, 1.0, 0.0, 0.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0, 0.0, 1.0]], + "max_grad": [[0.0, 0.0, 1.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]], + "mul_grad": [ + [3.0, 63.0, 1.0, 45.0, 35.0, 1.0], + [24.0, 12.0, 10.0, 8.0, 8.0, 1.0], + ], + }, + { + "src": [[[1, 2], [5, 6], [3, 4]], [[10, 11], [7, 9], [12, 13]]], + "index": [[0, 1, 0], [2, 0, 2]], + "dim": 1, + "sum": [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], + "add": [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], + "mul": [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 11 * 13]]], + "mean": [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]], + "min": [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]], + "arg_min": [[[0, 0], [1, 1], [3, 3]], [[1, 1], [3, 3], [0, 0]]], + "max": [[[3, 4], [5, 6], [0, 0]], [[7, 9], [0, 0], [12, 13]]], + "arg_max": [[[2, 2], [1, 1], [3, 3]], [[1, 1], [3, 3], [2, 2]]], + "sum_grad": [ + [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], + ], + "add_grad": [ + [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], + ], + "mean_grad": [ + [[0.5000, 0.5000], [1.0000, 1.0000], [0.5000, 0.5000]], + [[0.5000, 0.5000], [1.0000, 1.0000], [0.5000, 0.5000]], + ], + "min_grad": [ + [[1.0, 1.0], [1.0, 1.0], [0.0, 0.0]], + [[1.0, 1.0], [1.0, 1.0], [0.0, 0.0]], + ], + "max_grad": [ + [[0.0, 0.0], [1.0, 1.0], [1.0, 1.0]], + [[0.0, 0.0], [1.0, 1.0], [1.0, 1.0]], + ], + "mul_grad": [ + [[3.0, 4.0], [1.0, 1.0], [1.0, 2.0]], + [[12.0, 13.0], [1.0, 1.0], [10.0, 11.0]], + ], + }, + { + "src": [[1, 3], [2, 4]], + "index": [[0, 0], [0, 0]], + "dim": 1, + "sum": [[4], [6]], + "add": [[4], [6]], + "mul": [[3], [8]], + "mean": [[2], [3]], + "min": [[1], [2]], + "arg_min": [[0], [0]], + "max": [[3], [4]], + "arg_max": [[1], [1]], + "sum_grad": [[1.0, 1.0], [1.0, 1.0]], + "add_grad": [[1.0, 1.0], [1.0, 1.0]], + "mean_grad": [[0.5000, 0.5000], [0.5000, 0.5000]], + "min_grad": [[1.0, 0.0], [1.0, 0.0]], + "max_grad": [[0.0, 1.0], [0.0, 1.0]], + "mul_grad": [[3.0, 1.0], [4.0, 2.0]], + }, + { + "src": [[[1, 1], [3, 3]], [[2, 2], [4, 4]]], + "index": [[0, 0], [0, 0]], + "dim": 1, + "sum": [[[4, 4]], [[6, 6]]], + "add": [[[4, 4]], [[6, 6]]], + "mul": [[[3, 3]], [[8, 8]]], + "mean": [[[2, 2]], [[3, 3]]], + "min": [[[1, 1]], [[2, 2]]], + "arg_min": [[[0, 0]], [[0, 0]]], + "max": [[[3, 3]], [[4, 4]]], + "arg_max": [[[1, 1]], [[1, 1]]], + "sum_grad": [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], + "add_grad": [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], + "mean_grad": [ + [[0.5000, 0.5000], [0.5000, 0.5000]], + [[0.5000, 0.5000], [0.5000, 0.5000]], + ], + "min_grad": [[[1.0, 1.0], [0.0, 0.0]], [[1.0, 1.0], [0.0, 0.0]]], + "max_grad": [[[0.0, 0.0], [1.0, 1.0]], [[0.0, 0.0], [1.0, 1.0]]], + "mul_grad": [[[3.0, 3.0], [1.0, 1.0]], [[4.0, 4.0], [2.0, 2.0]]], + }, +] + + +@pytest.mark.parametrize( + "test,reduce,dtype,ind_dtype,place", + product(tests, reductions, dtypes, ind_dtypes, places), +) +def test_forward(test, reduce, dtype, ind_dtype, place): + paddle.set_device(place) + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + dim = test["dim"] + expected = tensor(test[reduce], dtype) + + fn = getattr(paddle_scatter, "scatter_" + reduce) + out = fn(src, index, dim) + if reduce == "min" or reduce == "max": 
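+ # min/max reducers return a (values, arg indices) pair, so unpack it and
+ # also check the recorded argmin/argmax positions.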
+ out, arg_out = out + arg_expected = tensor(test["arg_" + reduce], ind_dtype) + print(arg_out) + print(arg_expected) + assert paddle.all(arg_out == arg_expected) + assert paddle.all(out == expected) + + +@pytest.mark.parametrize("test,reduce,place", product(tests, reductions, places)) +def test_backward(test, reduce, place): + paddle.set_device(place) + index = tensor(test["index"], paddle.int64) + dim = test["dim"] + exp_grad = tensor(test[f"{reduce}_grad"], paddle.float64) + + src = tensor(test["src"], paddle.float64) + src.stop_gradient = False + out = paddle_scatter.scatter(src, index, dim, None, None, reduce) + out.backward() + np.testing.assert_allclose( + src.grad.numpy(), exp_grad.numpy(), rtol=1e-05, atol=1e-06 + ) + + +@pytest.mark.parametrize( + "test,reduce,dtype,ind_dtype,place", + product(tests, reductions, dtypes, ind_dtypes, places), +) +def test_out(test, reduce, dtype, ind_dtype, place): + paddle.set_device(place) + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + dim = test["dim"] + expected = tensor(test[reduce], dtype) + + out = paddle.full_like(expected, -2) + + getattr(paddle_scatter, "scatter_" + reduce)(src, index, dim, out) + + if reduce == "sum" or reduce == "add": + expected = expected - 2 + elif reduce == "mul": + expected = out # We can not really test this here. + elif reduce == "mean": + expected = out # We can not really test this here. + elif reduce == "min": + expected = expected.fill_(-2) + elif reduce == "max": + expected[expected == 0] = -2 + else: + raise ValueError + + np.testing.assert_allclose(out.numpy(), expected.numpy(), rtol=1e-05, atol=1e-06) + + +@pytest.mark.parametrize( + "test,reduce,dtype,ind_dtype,place", + product(tests, reductions, dtypes, ind_dtypes, places), +) +def test_non_contiguous(test, reduce, dtype, ind_dtype, place): + paddle.set_device(place) + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + dim = test["dim"] + expected = tensor(test[reduce], dtype) + + if src.dim() > 1: + shape = list(range(src.dim())) + shape[0], shape[1] = shape[1], shape[0] + src = src.transpose(shape).contiguous().transpose(shape) + if index.dim() > 1: + shape = list(range(index.dim())) + shape[0], shape[1] = shape[1], shape[0] + index = index.transpose(shape).contiguous().transpose(shape) + + out = getattr(paddle_scatter, "scatter_" + reduce)(src, index, dim) + if reduce == "min" or reduce == "max": + out, arg_out = out + arg_expected = tensor(test["arg_" + reduce], ind_dtype) + assert paddle.all(arg_out == arg_expected) + assert paddle.all(out == expected) diff --git a/jointContribution/paddle_scatter/tests/test_segment.py b/jointContribution/paddle_scatter/tests/test_segment.py new file mode 100644 index 000000000..4acb788aa --- /dev/null +++ b/jointContribution/paddle_scatter/tests/test_segment.py @@ -0,0 +1,297 @@ +from itertools import product + +import numpy as np +import paddle +import paddle_scatter +import pytest +from paddle_scatter.testing import dtypes +from paddle_scatter.testing import ind_dtypes +from paddle_scatter.testing import places +from paddle_scatter.testing import reductions +from paddle_scatter.testing import tensor + +tests = [ + { + "src": [1, 2, 3, 4, 5, 6], + "index": [0, 0, 1, 1, 1, 3], + "indptr": [0, 2, 5, 5, 6], + "sum": [3, 12, 0, 6], + "add": [3, 12, 0, 6], + "mean": [1.5, 4, 0, 6], + "min": [1, 3, 0, 6], + "arg_min": [0, 2, 6, 5], + "max": [2, 5, 0, 6], + "arg_max": [1, 4, 6, 5], + "sum_grad": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + "add_grad": [1.0, 1.0, 1.0, 
1.0, 1.0, 1.0], + "mean_grad": [0.5000, 0.5000, 0.33333333, 0.33333333, 0.33333333, 1.0000], + "min_grad": [1.0, 0.0, 1.0, 0.0, 0.0, 1.0], + "max_grad": [0.0, 1.0, 0.0, 0.0, 1.0, 1.0], + }, + { + "src": [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], + "index": [0, 0, 1, 1, 1, 3], + "indptr": [0, 2, 5, 5, 6], + "sum": [[4, 6], [21, 24], [0, 0], [11, 12]], + "add": [[4, 6], [21, 24], [0, 0], [11, 12]], + "mean": [[2, 3], [7, 8], [0, 0], [11, 12]], + "min": [[1, 2], [5, 6], [0, 0], [11, 12]], + "arg_min": [[0, 0], [2, 2], [6, 6], [5, 5]], + "max": [[3, 4], [9, 10], [0, 0], [11, 12]], + "arg_max": [[1, 1], [4, 4], [6, 6], [5, 5]], + "sum_grad": [ + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + ], + "add_grad": [ + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + ], + "mean_grad": [ + [0.5000, 0.5000], + [0.5000, 0.5000], + [0.33333333, 0.33333333], + [0.33333333, 0.33333333], + [0.33333333, 0.33333333], + [1.0000, 1.0000], + ], + "min_grad": [ + [1.0, 1.0], + [0.0, 0.0], + [1.0, 1.0], + [0.0, 0.0], + [0.0, 0.0], + [1.0, 1.0], + ], + "max_grad": [ + [0.0, 0.0], + [1.0, 1.0], + [0.0, 0.0], + [0.0, 0.0], + [1.0, 1.0], + [1.0, 1.0], + ], + }, + { + "src": [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]], + "index": [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]], + "indptr": [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]], + "sum": [[4, 21, 0, 11], [12, 18, 12, 0]], + "add": [[4, 21, 0, 11], [12, 18, 12, 0]], + "mean": [[2, 7, 0, 11], [4, 9, 12, 0]], + "min": [[1, 5, 0, 11], [2, 8, 12, 0]], + "arg_min": [[0, 2, 6, 5], [0, 3, 5, 6]], + "max": [[3, 9, 0, 11], [6, 10, 12, 0]], + "arg_max": [[1, 4, 6, 5], [2, 4, 5, 6]], + "sum_grad": [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], + "add_grad": [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]], + "mean_grad": [ + [0.5000, 0.5000, 0.33333333, 0.33333333, 0.33333333, 1.0000], + [0.33333333, 0.33333333, 0.33333333, 0.5000, 0.5000, 1.0000], + ], + "min_grad": [[1.0, 0.0, 1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0, 0.0, 1.0]], + "max_grad": [[0.0, 1.0, 0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 0.0, 1.0, 1.0]], + }, + { + "src": [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]], + "index": [[0, 0, 1], [0, 2, 2]], + "indptr": [[0, 2, 3, 3], [0, 1, 1, 3]], + "sum": [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], + "add": [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], + "mean": [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]], + "min": [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]], + "arg_min": [[[0, 0], [2, 2], [3, 3]], [[0, 0], [3, 3], [1, 1]]], + "max": [[[3, 4], [5, 6], [0, 0]], [[7, 9], [0, 0], [12, 13]]], + "arg_max": [[[1, 1], [2, 2], [3, 3]], [[0, 0], [3, 3], [2, 2]]], + "sum_grad": [ + [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], + ], + "add_grad": [ + [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], + ], + "mean_grad": [ + [[0.5000, 0.5000], [0.5000, 0.5000], [1.0000, 1.0000]], + [[1.0000, 1.0000], [0.5000, 0.5000], [0.5000, 0.5000]], + ], + "min_grad": [ + [[1.0, 1.0], [0.0, 0.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0], [0.0, 0.0]], + ], + "max_grad": [ + [[0.0, 0.0], [1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [0.0, 0.0], [1.0, 1.0]], + ], + }, + { + "src": [[1, 3], [2, 4]], + "index": [[0, 0], [0, 0]], + "indptr": [[0, 2], [0, 2]], + "sum": [[4], [6]], + "add": [[4], [6]], + "mean": [[2], [3]], + "min": [[1], [2]], + "arg_min": [[0], [0]], + "max": [[3], [4]], + "arg_max": 
[[1], [1]], + "sum_grad": [[1.0, 1.0], [1.0, 1.0]], + "add_grad": [[1.0, 1.0], [1.0, 1.0]], + "mean_grad": [[0.5000, 0.5000], [0.5000, 0.5000]], + "min_grad": [[1.0, 0.0], [1.0, 0.0]], + "max_grad": [[0.0, 1.0], [0.0, 1.0]], + }, + { + "src": [[[1, 1], [3, 3]], [[2, 2], [4, 4]]], + "index": [[0, 0], [0, 0]], + "indptr": [[0, 2], [0, 2]], + "sum": [[[4, 4]], [[6, 6]]], + "add": [[[4, 4]], [[6, 6]]], + "mean": [[[2, 2]], [[3, 3]]], + "min": [[[1, 1]], [[2, 2]]], + "arg_min": [[[0, 0]], [[0, 0]]], + "max": [[[3, 3]], [[4, 4]]], + "arg_max": [[[1, 1]], [[1, 1]]], + "sum_grad": [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], + "add_grad": [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], + "mean_grad": [ + [[0.5000, 0.5000], [0.5000, 0.5000]], + [[0.5000, 0.5000], [0.5000, 0.5000]], + ], + "min_grad": [[[1.0, 1.0], [0.0, 0.0]], [[1.0, 1.0], [0.0, 0.0]]], + "max_grad": [[[0.0, 0.0], [1.0, 1.0]], [[0.0, 0.0], [1.0, 1.0]]], + }, +] + + +@pytest.mark.parametrize( + "test,reduce,dtype,ind_dtype,place", + product(tests, reductions, dtypes, ind_dtypes, places), +) +def test_forward(test, reduce, dtype, ind_dtype, place): + paddle.set_device(place) + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + indptr = tensor(test["indptr"], ind_dtype) + expected = tensor(test[reduce], dtype) + + fn = getattr(paddle_scatter, "segment_" + reduce + "_csr") + out = fn(src, indptr) + if reduce == "min" or reduce == "max": + out, arg_out = out + arg_expected = tensor(test["arg_" + reduce], ind_dtype) + assert paddle.all(arg_out == arg_expected) + assert paddle.all(out == expected) + + fn = getattr(paddle_scatter, "segment_" + reduce + "_coo") + out = fn(src, index) + if reduce == "min" or reduce == "max": + out, arg_out = out + arg_expected = tensor(test["arg_" + reduce], ind_dtype) + assert paddle.all(arg_out == arg_expected) + assert paddle.all(out == expected) + + +@pytest.mark.parametrize("test,reduce,place", product(tests, reductions, places)) +def test_backward(test, reduce, place): + paddle.set_device(place) + index = tensor(test["index"], paddle.int64) + indptr = tensor(test["indptr"], paddle.int64) + exp_grad = tensor(test[f"{reduce}_grad"], paddle.float64) + + src = tensor(test["src"], paddle.float64) + src.stop_gradient = False + out = paddle_scatter.segment_csr(src, indptr, None, reduce) + out.backward() + np.testing.assert_allclose( + src.grad.numpy(), exp_grad.numpy(), rtol=1e-05, atol=1e-06 + ) + + src = tensor(test["src"], paddle.float64) + src.stop_gradient = False + out = paddle_scatter.segment_coo(src, index, None, None, reduce) + out.backward() + np.testing.assert_allclose( + src.grad.numpy(), exp_grad.numpy(), rtol=1e-05, atol=1e-06 + ) + + +@pytest.mark.parametrize( + "test,reduce,dtype,ind_dtype,place", + product(tests, reductions, dtypes, ind_dtypes, places), +) +def test_out(test, reduce, dtype, ind_dtype, place): + paddle.set_device(place) + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + indptr = tensor(test["indptr"], ind_dtype) + expected = tensor(test[reduce], dtype) + + out = paddle.full_like(expected, -2) + + getattr(paddle_scatter, "segment_" + reduce + "_csr")(src, indptr, out) + np.testing.assert_allclose(out.numpy(), expected.numpy(), rtol=1e-05, atol=1e-06) + + out.fill_(-2) + + getattr(paddle_scatter, "segment_" + reduce + "_coo")(src, index, out) + + if reduce == "sum" or reduce == "add": + expected = expected - 2 + elif reduce == "mean": + expected = out # We can not really test this here. 
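+ # For "min" every pre-filled -2 entry is smaller than any value in src,
+ # so the whole expected buffer collapses to -2.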
+ elif reduce == "min": + expected = expected.fill_(-2) + elif reduce == "max": + expected[expected == 0] = -2 + else: + raise ValueError + + assert paddle.all(out == expected) + + +@pytest.mark.parametrize( + "test,reduce,dtype,ind_dtype,place", + product(tests, reductions, dtypes, ind_dtypes, places), +) +def test_non_contiguous(test, reduce, dtype, ind_dtype, place): + paddle.set_device(place) + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + indptr = tensor(test["indptr"], ind_dtype) + expected = tensor(test[reduce], dtype) + + if src.dim() > 1: + shape = list(range(src.dim())) + shape[0], shape[1] = shape[1], shape[0] + src = src.transpose(shape).contiguous().transpose(shape) + if index.dim() > 1: + shape = list(range(index.dim())) + shape[0], shape[1] = shape[1], shape[0] + index = index.transpose(shape).contiguous().transpose(shape) + if indptr.dim() > 1: + shape = list(range(index.dim())) + shape[0], shape[1] = shape[1], shape[0] + indptr = indptr.transpose(shape).contiguous().transpose(shape) + + out = getattr(paddle_scatter, "segment_" + reduce + "_csr")(src, indptr) + if reduce == "min" or reduce == "max": + out, arg_out = out + arg_expected = tensor(test["arg_" + reduce], ind_dtype) + assert paddle.all(arg_out == arg_expected) + assert paddle.all(out == expected) + + out = getattr(paddle_scatter, "segment_" + reduce + "_coo")(src, index) + if reduce == "min" or reduce == "max": + out, arg_out = out + arg_expected = tensor(test["arg_" + reduce], ind_dtype) + assert paddle.all(arg_out == arg_expected) + assert paddle.all(out == expected) diff --git a/jointContribution/paddle_scatter/tests/test_zero_tensors.py b/jointContribution/paddle_scatter/tests/test_zero_tensors.py new file mode 100644 index 000000000..2aaba1c44 --- /dev/null +++ b/jointContribution/paddle_scatter/tests/test_zero_tensors.py @@ -0,0 +1,45 @@ +from itertools import product + +import paddle +import pytest +from paddle_scatter import gather_coo +from paddle_scatter import gather_csr +from paddle_scatter import scatter +from paddle_scatter import segment_coo +from paddle_scatter import segment_csr +from paddle_scatter.testing import grad_dtypes +from paddle_scatter.testing import ind_dtypes +from paddle_scatter.testing import places +from paddle_scatter.testing import reductions +from paddle_scatter.testing import tensor + + +@pytest.mark.parametrize( + "reduce,dtype,ind_dtype,place", product(reductions, grad_dtypes, ind_dtypes, places) +) +def test_zero_elements(reduce, dtype, ind_dtype, place): + paddle.set_device(place) + x = paddle.randn([0, 0, 0, 16], dtype=dtype) + x.stop_gradient = False + index = tensor([], ind_dtype) + indptr = tensor([], ind_dtype) + + out = scatter(x, index, dim=0, dim_size=0, reduce=reduce) + out.backward(paddle.randn(out.shape, out.dtype)) + assert out.shape == [0, 0, 0, 16] + + out = segment_coo(x, index, dim_size=0, reduce=reduce) + out.backward(paddle.randn(out.shape, out.dtype)) + assert out.shape == [0, 0, 0, 16] + + out = gather_coo(x, index) + out.backward(paddle.randn(out.shape, out.dtype)) + assert out.shape == [0, 0, 0, 16] + + out = segment_csr(x, indptr, reduce=reduce) + out.backward(paddle.randn(out.shape, out.dtype)) + assert out.shape == [0, 0, 0, 16] + + out = gather_csr(x, indptr) + out.backward(paddle.randn(out.shape, out.dtype)) + assert out.shape == [0, 0, 0, 16] diff --git a/jointContribution/paddle_scatter/utils.py b/jointContribution/paddle_scatter/utils.py new file mode 100644 index 000000000..b2db4874f --- /dev/null +++ 
b/jointContribution/paddle_scatter/utils.py @@ -0,0 +1,73 @@ +import paddle + + +def broadcast(src: paddle.Tensor, other: paddle.Tensor, dim: int) -> paddle.Tensor: + r"""Broadcast `src` to `other` at dimension `dim`. + + Denote dim = :math:`i`, + other.shape = :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})`, + src.shape = :math:(x_i,)`, src = :math:`[y_0, ..., y_{x_i-1}]`, + where each element satisfying 0 <= element < x_i + + This util function broadcast `src` to the shape of `other`'s. + + Notes: + The shape of `src` should be either :math:`(x_i,)` or :math:`(*_0, ..., *_{i-1}, x_i)`, + where :math:`*_k (k <= i-1)` should be either :math:`1` or :math:`x_k`. + + Args: + src (paddle.Tensor): The tensor to be broadcasted. + other (paddle.Tensor): The tensor to be broadcasted to. + dim (int): The target dimension of `other`. + + Returns: + paddle.Tensor, the broadcasted tensor. + """ + if dim < 0: + dim = other.dim() + dim + if src.dim() == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.dim(), other.dim()): + src = src.unsqueeze(-1) + if other.numel() == 0: + return src.reshape(other.shape) + src = src.expand(other.shape) + return src + + +def transform_3d(tensor: paddle.Tensor, dim: int) -> paddle.Tensor: + r"""Transform a tensor to [pre, dim, post] 3d dimensions. + + Args: + tensor (paddle.Tensor): The input tensor. + dim (int): The target dimension. + + Returns: + paddle.Tensor, the transformed tensor. + """ + if tensor.dim() == 1: + return tensor.unsqueeze(0).unsqueeze(-1) + elif tensor.dim() == 2: + if dim == 0: + return tensor.unsqueeze(0) + elif dim == 1: + return tensor.unsqueeze(-1) + else: + return tensor.flatten(0, dim - 1).flatten(2, -1) + + +def transform_2d(tensor: paddle.Tensor, dim: int) -> paddle.Tensor: + r"""Transform a tensor to [pre, dim] 2d dimensions. + + Args: + tensor (paddle.Tensor): The input tensor. + dim (int): The target dimension. + + Returns: + paddle.Tensor, the transformed tensor. 
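+ For example, a tensor of shape (A, B, C) with dim = 2 is flattened to
+ shape (A * B, C), while a 1-d input is returned with shape (1, N).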
+ """ + if tensor.dim() == 1: + return tensor.unsqueeze(0) + else: + return tensor.flatten(0, dim - 1) From 3993fb4b0c19d9226a31b16bde0f986acb9db74d Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 2 Dec 2024 11:22:14 +0800 Subject: [PATCH 2/2] update --- .../paddle_scatter/composite/logsumexp.py | 2 +- .../paddle_scatter/composite/softmax.py | 2 +- .../paddle_scatter/composite/std.py | 2 +- .../paddle_scatter/csrc/atomics.cuh | 31 + .../paddle_scatter/csrc/index_info.cuh | 7 +- .../paddle_scatter/csrc/index_info.h | 7 +- .../paddle_scatter/csrc/scatter_min_max.cu | 2 +- .../paddle_scatter/csrc/segment_coo.cc | 500 ++++++++++++++ .../paddle_scatter/csrc/segment_coo.cu | 610 ++++++++++++++++++ .../paddle_scatter/csrc/segment_csr.cc | 409 ++++++++++++ .../paddle_scatter/csrc/segment_csr.cu | 382 +++++++++++ .../paddle_scatter/csrc/utils.cuh | 2 +- jointContribution/paddle_scatter/scatter.py | 73 ++- .../paddle_scatter/segment_coo.py | 195 +++--- .../paddle_scatter/segment_csr.py | 161 ++--- jointContribution/paddle_scatter/setup.py | 4 +- jointContribution/paddle_scatter/testing.py | 9 +- .../tests/composite/test_softmax.py | 2 +- .../paddle_scatter/tests/test_gather.py | 115 ++++ .../paddle_scatter/tests/test_multi_gpu.py | 4 +- .../paddle_scatter/tests/test_scatter.py | 122 +++- .../paddle_scatter/tests/test_segment.py | 147 +++++ jointContribution/paddle_scatter/utils.py | 37 -- 23 files changed, 2501 insertions(+), 324 deletions(-) create mode 100644 jointContribution/paddle_scatter/csrc/segment_coo.cc create mode 100644 jointContribution/paddle_scatter/csrc/segment_coo.cu create mode 100644 jointContribution/paddle_scatter/csrc/segment_csr.cc create mode 100644 jointContribution/paddle_scatter/csrc/segment_csr.cu diff --git a/jointContribution/paddle_scatter/composite/logsumexp.py b/jointContribution/paddle_scatter/composite/logsumexp.py index 3d66d040a..a3bed5c09 100644 --- a/jointContribution/paddle_scatter/composite/logsumexp.py +++ b/jointContribution/paddle_scatter/composite/logsumexp.py @@ -42,7 +42,7 @@ def scatter_logsumexp( ) index = broadcast(index, src, dim) - eps = paddle.to_tensor(eps, dtype=src.dtype) + eps = paddle.full([], eps, dtype=src.dtype) if out is not None: dim_size = out.shape[dim] diff --git a/jointContribution/paddle_scatter/composite/softmax.py b/jointContribution/paddle_scatter/composite/softmax.py index b31b54d44..dfa2eaa8c 100644 --- a/jointContribution/paddle_scatter/composite/softmax.py +++ b/jointContribution/paddle_scatter/composite/softmax.py @@ -84,7 +84,7 @@ def scatter_log_softmax( ) index = broadcast(index, src, dim) - eps = paddle.to_tensor(eps, dtype=src.dtype) + eps = paddle.full([], eps, dtype=src.dtype) max_value_per_index = scatter_max(src, index, dim=dim, dim_size=dim_size)[0] max_per_src_element = max_value_per_index.take_along_axis(indices=index, axis=dim) diff --git a/jointContribution/paddle_scatter/composite/std.py b/jointContribution/paddle_scatter/composite/std.py index 359e0b50d..6d27f4ad4 100644 --- a/jointContribution/paddle_scatter/composite/std.py +++ b/jointContribution/paddle_scatter/composite/std.py @@ -57,7 +57,7 @@ def scatter_std( res = scatter_sum(var, index, dim, out, dim_size) if unbiased: - count = count.subtract(paddle.to_tensor(1, dtype=src.dtype)).clip(1) + count = count.subtract(paddle.full([], 1, dtype=src.dtype)).clip(1) res = res.divide(count + 1e-6).sqrt() if out is not None: diff --git a/jointContribution/paddle_scatter/csrc/atomics.cuh 
b/jointContribution/paddle_scatter/csrc/atomics.cuh index 224d29889..f5b5de08d 100644 --- a/jointContribution/paddle_scatter/csrc/atomics.cuh +++ b/jointContribution/paddle_scatter/csrc/atomics.cuh @@ -150,3 +150,34 @@ static inline __device__ void atomMin(float *address, float val) { static inline __device__ void atomMin(double *address, double val) { AtomicMinDecimalImpl()(address, val); } + +#define OP(X, Y) Y + X +ATOMIC(Add) +#undef OP +static inline __device__ void atomAdd(uint8_t *address, uint8_t val) { + AtomicAddIntegerImpl()(address, val); +} +static inline __device__ void atomAdd(int8_t *address, int8_t val) { + AtomicAddIntegerImpl()(address, val); +} +static inline __device__ void atomAdd(int16_t *address, int16_t val) { + AtomicAddIntegerImpl()(address, val); +} +static inline __device__ void atomAdd(int32_t *address, int32_t val) { + atomicAdd(address, val); +} +static inline __device__ void atomAdd(int64_t *address, int64_t val) { + AtomicAddIntegerImpl()(address, val); +} +static inline __device__ void atomAdd(float *address, float val) { + atomicAdd(address, val); +} +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000) +static inline __device__ void atomAdd(double *address, double val) { + AtomicAddDecimalImpl()(address, val); +} +#else +static inline __device__ void atomAdd(double *address, double val) { + atomicAdd(address, val); +} +#endif diff --git a/jointContribution/paddle_scatter/csrc/index_info.cuh b/jointContribution/paddle_scatter/csrc/index_info.cuh index 64ecd28de..f7bb665af 100644 --- a/jointContribution/paddle_scatter/csrc/index_info.cuh +++ b/jointContribution/paddle_scatter/csrc/index_info.cuh @@ -2,14 +2,14 @@ #include "paddle/extension.h" -#define MAX_TENSORINFO_DIMS 25 +#define MAX_TENSORINFO_DIMS 7 template struct TensorInfo { TensorInfo(const T *p, int dim, int sz[MAX_TENSORINFO_DIMS], int st[MAX_TENSORINFO_DIMS]) { data = p; dims = dim; - PD_CHECK(dims < MAX_TENSORINFO_DIMS, "Input dims should be smaller than 25."); + PD_CHECK(dims < MAX_TENSORINFO_DIMS, "Input dims should be smaller than 7."); for (int i = 0; i < dim; ++i) { sizes[i] = sz[i]; @@ -30,11 +30,8 @@ TensorInfo getTensorInfo(const paddle::Tensor &tensor) { int strides[MAX_TENSORINFO_DIMS]; int dims = tensor.shape().size(); - // int stride_i = 1; for (int i = dims - 1; i >= 0; --i) { sizes[i] = tensor.shape()[i]; - // strides[i] = stride_i; - // stride_i *= sizes[i]; sizes[i] = tensor.strides()[i]; } diff --git a/jointContribution/paddle_scatter/csrc/index_info.h b/jointContribution/paddle_scatter/csrc/index_info.h index c2751898b..3d27cacc7 100644 --- a/jointContribution/paddle_scatter/csrc/index_info.h +++ b/jointContribution/paddle_scatter/csrc/index_info.h @@ -2,14 +2,14 @@ #include "paddle/extension.h" -#define MAX_TENSORINFO_DIMS 25 +#define MAX_TENSORINFO_DIMS 7 template struct TensorInfo { TensorInfo(const T *p, int dim, int sz[MAX_TENSORINFO_DIMS], int st[MAX_TENSORINFO_DIMS]) { data = p; dims = dim; - PD_CHECK(dims < MAX_TENSORINFO_DIMS, "Input dims should be smaller than 25."); + PD_CHECK(dims < MAX_TENSORINFO_DIMS, "Input dims should be smaller than 7."); for (int i = 0; i < dim; ++i) { sizes[i] = sz[i]; @@ -29,11 +29,8 @@ TensorInfo getTensorInfo(const paddle::Tensor &tensor) { int strides[MAX_TENSORINFO_DIMS]; int dims = tensor.shape().size(); - // int stride_i = 1; for (int i = dims - 1; i >= 0; --i) { sizes[i] = tensor.shape()[i]; - // strides[i] = stride_i; - // stride_i *= sizes[i]; strides[i] = tensor.strides()[i]; } diff --git 
a/jointContribution/paddle_scatter/csrc/scatter_min_max.cu b/jointContribution/paddle_scatter/csrc/scatter_min_max.cu index 5fccb6f6a..089def620 100644 --- a/jointContribution/paddle_scatter/csrc/scatter_min_max.cu +++ b/jointContribution/paddle_scatter/csrc/scatter_min_max.cu @@ -135,7 +135,7 @@ std::vector scatter_min_max_cuda_forward(const paddle::Tensor& x using MPType = typename MPTypeTrait::Type; paddle::Tensor out_mp; if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) { - out_mp = paddle::empty(return_shape, paddle::DataType::FLOAT32, x.place()); + out_mp = paddle::experimental::cast(out, paddle::DataType::FLOAT32); } else { out_mp = out; } diff --git a/jointContribution/paddle_scatter/csrc/segment_coo.cc b/jointContribution/paddle_scatter/csrc/segment_coo.cc new file mode 100644 index 000000000..f763d50e9 --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/segment_coo.cc @@ -0,0 +1,500 @@ +#include "paddle/extension.h" + +#include +#include +#include + +#include "index_info.h" +#include "utils.h" + + +template +void segment_coo_cpu_forward_kernel(const data_t* x_data, + const index_t* index_data, + const std::vector& return_shape, + const std::vector& x_dims, + const std::string& reduce, + const TensorInfo& index_info, + int64_t x_numel, + int64_t dim, + bool post_process, + data_t* out_data, + index_t* arg_out_data, + data_t* count_data) { + using MPType = typename MPTypeTrait::Type; + auto B = 1; + for (auto i = 0; i < dim; ++i) + B *= x_dims[i]; + auto E = x_dims[dim]; + auto K = x_numel / (B * E); + auto N = return_shape[dim]; + + auto stride = index_info.strides[index_info.dims - 1]; + std::vector args(K); + std::vector vals(K); + + int64_t idx, next_idx, row_start; + for (auto b = 0; b < B; b++) { + auto offset = IndexToOffset::get(b * E, index_info); + idx = index_info.data[offset]; + + for (auto k = 0; k < K; k++) + vals[k] = static_cast(out_data[b * N * K + k]); + + row_start = 0; + for (auto e = 0; e < E; e++) { + // update + for (auto k = 0; k < K; k++) { + auto cmp = static_cast(x_data[b * E * K + e * K + k]); + if ((reduce == "min" && cmp < vals[k]) || + (reduce == "max" && cmp > vals[k])) { + vals[k] = cmp; + args[k] = e; + } else if (reduce == "sum" || reduce == "mean") { + vals[k] += cmp; + } + } + //write + if (e == E - 1) { + for (auto k = 0; k < K; k++) { + auto idx_k = b * N * K + idx * K + k; + auto count = E - row_start; + if (reduce == "sum") { + out_data[idx_k] = static_cast(vals[k]); + } else if (reduce == "mean") { + out_data[idx_k] = static_cast(vals[k] / static_cast(count > 0 ? count : 1)); + } else if (reduce == "min" || reduce == "max") { + if (count > 0) { + out_data[idx_k] = static_cast(vals[k]); + arg_out_data[idx_k] = args[k]; + } else { + out_data[idx_k] = static_cast(0); + } + } + } + + if (reduce == "mean") + count_data[b * N + idx] = static_cast(e + 1 - row_start); + } else { + next_idx = index_info.data[offset + (e + 1) * stride]; + assert(idx <= next_idx); + + if (idx != next_idx) { + //write + for (auto k = 0; k < K; k++) { + auto idx_k = b * N * K + idx * K + k; + auto count = e + 1 - row_start; + if (reduce == "sum") { + out_data[idx_k] = static_cast(vals[k]); + } else if (reduce == "mean") { + out_data[idx_k] = static_cast(vals[k] / static_cast(count > 0 ? 
count : 1)); + } else if (reduce == "min" || reduce == "max") { + if (count > 0) { + out_data[idx_k] = static_cast(vals[k]); + arg_out_data[idx_k] = args[k]; + } else { + out_data[idx_k] = static_cast(0); + } + } + + vals[k] = static_cast(out_data[b * N * K + next_idx * K + k]); + } + if (reduce == "mean") + count_data[b * N + idx] = static_cast(e + 1 - row_start); + row_start = e + 1; + } + + idx = next_idx; + } + } + } + + if (post_process) { + if (reduce == "min" || reduce == "max") { + auto out_numel = std::accumulate(return_shape.begin(), return_shape.end(), 1.0, std::multiplies()); + data_t init_val = static_cast((reduce == "min") ? std::numeric_limits::max() : std::numeric_limits::lowest()); + for (auto i = 0; i < out_numel; ++i) { + if (out_data[i] == init_val) + out_data[i] = static_cast(0.0); + } + } + if (reduce == "mean") { + auto count_data_numel = sizeof(count_data) / sizeof(data_t); + for (auto i = 0; i < count_data_numel; ++i) { + if (count_data[i] < static_cast(1.0)) + count_data[i] = static_cast(1.0); + } + } + } +} + +std::vector segment_coo_cpu_forward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + const std::vector& return_shape, + const std::string& reduce) { + CHECK_CPU(index); + if (init) + CHECK_CPU(init.get()); + + auto x_dims = x.shape(); + auto index_dims = index.shape(); + CHECK_INPUT(x_dims.size() >= index_dims.size()); + for (auto i = 0; i < index_dims.size() - 1; ++i) + CHECK_INPUT(x_dims[i] >= index_dims[i]); + + auto dim = index_dims.size() - 1; + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + paddle::Tensor arg_out; + if (reduce == "min" || reduce == "max") { + arg_out = paddle::experimental::full_like(out, x_dims[dim], index.dtype(), index.place()); + } else if (reduce == "mean") { + auto sizes = index.shape(); + sizes[dim] = return_shape[dim]; + arg_out = paddle::zeros(sizes, out.dtype(), index.place()); + } + + PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES( + x.dtype(), "segment_coo_cpu_forward_kernel", ([&] { + + using MPType = typename MPTypeTrait::Type; + if (!init) { + if (reduce == "min") + paddle::experimental::fill_(out, static_cast(std::numeric_limits::max())); + else if (reduce == "max") + paddle::experimental::fill_(out, static_cast(std::numeric_limits::lowest())); + else if (reduce == "sum" || reduce == "mean") + paddle::experimental::fill_(out, static_cast(0)); + } + + bool post_process = (!init) ? true : false; + switch(index.dtype()) { + case paddle::DataType::INT32: + { + auto index_info = getTensorInfo(index); + segment_coo_cpu_forward_kernel( + x.data(), + index.data(), + return_shape, + x_dims, + reduce, + index_info, + x.numel(), + dim, + post_process, + out.data(), + (reduce == "min" || reduce == "max") ? arg_out.data() : nullptr, + (reduce == "mean") ? arg_out.data() : nullptr); + break; + } + case paddle::DataType::INT64: + { + auto index_info = getTensorInfo(index); + segment_coo_cpu_forward_kernel( + x.data(), + index.data(), + return_shape, + x_dims, + reduce, + index_info, + x.numel(), + dim, + post_process, + out.data(), + (reduce == "min" || reduce == "max") ? arg_out.data() : nullptr, + (reduce == "mean") ? 
arg_out.data() : nullptr); + break; + } + default: + PD_THROW( + "function segment_coo_cpu_forward_kernel is not implemented for the index data type `", + phi::DataTypeToString(index.dtype()), "`"); + } + })); + + return {out, arg_out}; +} + +template +void gather_coo_cpu_forward_kernel(const data_t* x_data, + const TensorInfo& index_info, + int stride, + int B, + int E, + int K, + int N, + data_t* out_data) { + std::vector vals(K); + int64_t idx, next_idx; + for (auto b = 0; b < B; b++) { + auto offset = IndexToOffset::get(b * E, index_info); + idx = index_info.data[offset]; + + for (auto k = 0; k < K; k++) + vals[k] = x_data[b * N * K + idx * K + k]; + + for (auto e = 0; e < E; e++) { + for (auto k = 0; k < K; k++) + out_data[b * E * K + e * K + k] = vals[k]; + + if (e < E - 1) { + next_idx = index_info.data[offset + (e + 1) * stride]; + CHECK_INPUT(idx <= next_idx); + + if (idx != next_idx) { + idx = next_idx; + for (auto k = 0; k < K; k++) + vals[k] = x_data[b * N * K + idx * K + k]; + } + } + } + } +} + +std::vector gather_coo_cpu_forward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + std::vector return_shape) { + CHECK_CPU(index); + if (init) + CHECK_CPU(init.get()); + + auto x_dims = x.shape(); + auto index_dims = index.shape(); + CHECK_INPUT(x_dims.size() >= index_dims.size()); + for (auto i = 0; i < index_dims.size() - 1; ++i) + CHECK_INPUT(x_dims[i] == index_dims[i]); + + auto dim = index_dims.size() - 1; + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + auto B = index.numel() / return_shape[dim]; + auto E = index_dims[dim]; + auto K = out.numel() / index.numel(); + auto N = x_dims[dim]; + + PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES( + x.dtype(), "gather_coo_cpu_forward_kernel", ([&] { + switch(index.dtype()) { + case paddle::DataType::INT32: + { + auto index_info = getTensorInfo(index); + auto stride = index_info.strides[index_info.dims - 1]; + gather_coo_cpu_forward_kernel( + x.data(), index_info, stride, + B, E, K, N, out.data()); + break; + } + case paddle::DataType::INT64: + { + auto index_info = getTensorInfo(index); + auto stride = index_info.strides[index_info.dims - 1]; + gather_coo_cpu_forward_kernel( + x.data(), index_info, stride, + B, E, K, N, out.data()); + break; + } + default: + PD_THROW( + "function gather_coo_cpu_forward_kernel is not implemented for the index data type `", + phi::DataTypeToString(index.dtype()), "`"); + } + })); + + return {out}; +} + +std::vector segment_coo_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + std::vector return_shape, + std::string reduce); + +std::vector SegmentCooForward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + std::vector return_shape, + std::string reduce) { + if (x.is_cpu()) { + return segment_coo_cpu_forward(x, index, init, return_shape, reduce); + } else if (x.is_gpu()) { + return segment_coo_cuda_forward(x, index, init, return_shape, reduce); + } else { + PD_THROW("Unsupported device type for forward function of custom segment_coo operator."); + } +} + + +std::vector gather_coo_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + std::vector 
return_shape); + +std::vector GatherCooForward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + std::vector return_shape) { + if (x.is_cpu()) { + return gather_coo_cpu_forward(x, index, init, return_shape); + } else if (x.is_gpu()) { + return gather_coo_cuda_forward(x, index, init, return_shape); + } else { + PD_THROW("Unsupported device type for forward function of custom gather_coo operator."); + } +} + +std::vector SegmentCooBackward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& arg_out, + const paddle::Tensor& grad_out, + std::string reduce) { + if (!x.is_cpu() && !x.is_gpu() ) { + PD_THROW("Unsupported device type for backward function of custom segment_coo operator."); + } + if (reduce == "min" || reduce == "max") { + int64_t dim = index.shape().size() - 1; + auto x_shape = x.shape(); + x_shape[dim] += 1; + auto grad_x = paddle::zeros(x_shape, x.dtype(), x.place()); + paddle::experimental::put_along_axis_(grad_x, arg_out.get(), grad_out, dim); + grad_x = paddle::experimental::slice(grad_x, {dim}, {0}, {x_shape[dim] - 1}, {1}, {}); + return {grad_x}; + } else if (reduce == "mean") { + auto grad_x = paddle::empty(x.shape(), x.dtype(), x.place()); + paddle::Tensor count = arg_out.get(); + paddle::Tensor grad_in = GatherCooForward(grad_out, index, paddle::optional(grad_x), grad_x.shape())[0]; + auto sizes = arg_out.get().shape(); + int64_t dim = index.shape().size() - 1; + sizes[dim] = index.shape()[dim]; + count = GatherCooForward(count, index, paddle::optional(paddle::none), sizes)[0]; + for (auto i = 0; i < grad_out.shape().size() - index.shape().size(); i++) + count = paddle::experimental::unsqueeze(count, {-1}); + paddle::experimental::divide_(grad_in, count); + return {grad_in}; + } else if (reduce == "sum") { + auto grad_x = paddle::empty(x.shape(), x.dtype(), x.place()); + paddle::Tensor grad_in = GatherCooForward(grad_out, index, paddle::optional(grad_x), grad_x.shape())[0]; + return {grad_in}; + } +} + +std::vector GatherCooBackward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::Tensor& grad_out) { + if (!x.is_cpu() && !x.is_gpu() ) { + PD_THROW("Unsupported device type for backward function of custom gather_coo operator."); + } + auto x_shape = x.shape(); + auto grad_x = paddle::zeros(x_shape, x.dtype(), x.place()); + paddle::Tensor grad_in = SegmentCooForward(grad_out, index, paddle::optional(grad_x), grad_x.shape(), "sum")[0]; + return {grad_in}; +} + +std::vector> SegmentCooFWInferShape(const std::vector& x_shape, + const std::vector& index_shape, + const paddle::optional>& init_shape, + std::vector return_shape, + std::string reduce) { + return {return_shape, return_shape}; +} + +std::vector SegmentCooFWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& index_dtype, + const paddle::optional& init_dtype) { + return {x_dtype, index_dtype}; +} + +std::vector> SegmentCooBWInferShape(const std::vector& x_shape, + const std::vector& index_shape, + const paddle::optional>& arg_out_shape, + const std::vector& grad_out_shape) { + return {x_shape}; +} + +std::vector SegmentCooBWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& index_dtype, + const paddle::optional& arg_out_dtype, + const paddle::DataType& grad_out_dtype) { + return {x_dtype}; +} + +std::vector> GatherCooFWInferShape(const std::vector& x_shape, + const std::vector& index_shape, + const paddle::optional>& init_shape, + std::vector return_shape) { + return {return_shape}; +} + 
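+// The InferShape/InferDtype helpers above and below are wired into the
+// PD_BUILD_OP registrations further down; they let Paddle derive the output
+// shapes and dtypes of the custom ops without executing the kernels.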
+std::vector GatherCooFWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& index_dtype, + const paddle::optional& init_dtype) { + return {x_dtype}; +} + +std::vector> GatherCooBWInferShape(const std::vector& x_shape, + const std::vector& index_shape, + const std::vector& grad_out_shape) { + return {x_shape}; +} + +std::vector GatherCooBWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& index_dtype, + const paddle::DataType& grad_out_dtype) { + return {x_dtype}; +} + + +PD_BUILD_OP(custom_segment_coo) + .Inputs({"X", "Index", paddle::Optional("Init")}) + .Outputs({"Out", paddle::Optional("ArgOut")}) + .Attrs({"return_shape: std::vector", + "reduce: std::string"}) + .SetKernelFn(PD_KERNEL(SegmentCooForward)) + .SetInferShapeFn(PD_INFER_SHAPE(SegmentCooFWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SegmentCooFWInferDtype)); + +PD_BUILD_GRAD_OP(custom_segment_coo) + .Inputs({"X", "Index", paddle::Optional("ArgOut"), paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .Attrs({"reduce: std::string"}) + .SetKernelFn(PD_KERNEL(SegmentCooBackward)) + .SetInferShapeFn(PD_INFER_SHAPE(SegmentCooBWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SegmentCooBWInferDtype)); + +PD_BUILD_OP(custom_gather_coo) + .Inputs({"X", "Index", paddle::Optional("Init")}) + .Outputs({"Out"}) + .Attrs({"return_shape: std::vector"}) + .SetKernelFn(PD_KERNEL(GatherCooForward)) + .SetInferShapeFn(PD_INFER_SHAPE(GatherCooFWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(GatherCooFWInferDtype)); + +PD_BUILD_GRAD_OP(custom_gather_coo) + .Inputs({"X", "Index", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(GatherCooBackward)) + .SetInferShapeFn(PD_INFER_SHAPE(GatherCooBWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(GatherCooBWInferDtype)); diff --git a/jointContribution/paddle_scatter/csrc/segment_coo.cu b/jointContribution/paddle_scatter/csrc/segment_coo.cu new file mode 100644 index 000000000..959331128 --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/segment_coo.cu @@ -0,0 +1,610 @@ +#include "paddle/extension.h" + +#include +#include +#include + +#include "atomics.cuh" +#include "index_info.cuh" +#include "utils.cuh" + +#define THREADS 256 +#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS +#define FULL_MASK 0xffffffff + + +enum ReductionType { MIN, MAX, SUM, MEAN }; + +const std::map reduce2REDUCE = { + {"min", MIN}, {"max", MAX}, {"sum", SUM}, {"mean", MEAN} +}; + +bool is_floating_point(const phi::DataType& dtype) { + return dtype == phi::DataType::BFLOAT16 || dtype == phi::DataType::FLOAT16 || + dtype == phi::DataType::FLOAT32 || dtype == phi::DataType::FLOAT64; +} + +template +__global__ void +segment_coo_cuda_forward_kernel(const data_t* x_data, + const TensorInfo index_info, + ReductionType reduce_type, + mp_t* out_data, + size_t E, + size_t N) { + + // Each thread processes exactly one entry. Within a warp, we perform a + // parallel reduction across equal indices, and write the intermediate + // result via atomics. + + int row_idx = blockIdx.x * blockDim.x + threadIdx.x; + int lane_idx = row_idx & (32 - 1); + int D = index_info.sizes[index_info.dims - 1]; + + if (row_idx < E) { + int offset = IndexToOffset::get(row_idx, index_info); + int64_t idx = index_info.data[offset], next_idx; + int out_idx = (row_idx / D) * N + idx; + + mp_t val = HAS_VAL ? static_cast(x_data[row_idx]) : (mp_t)1, tmp; + +#pragma unroll + for (int i = 1; i < 32; i *= 2) { + // Parallel reduction inside a single warp. 
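+ // Each step reads the value and segment index from the lane `i` positions
+ // below; only lanes whose neighbour carries the same index fold the value
+ // in, so the last lane of each run of equal indices ends up with that run's
+ // result and writes it once via atomics after the loop.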
+ tmp = SHFL_UP_SYNC(FULL_MASK, val, i); + next_idx = SHFL_UP_SYNC(FULL_MASK, idx, i); + if (lane_idx >= i && row_idx / D == (row_idx - i) / D) { + assert(idx >= next_idx); + if (idx == next_idx) { + // update + if ((reduce_type == MIN && tmp < val) || + (reduce_type == MAX && tmp > val)) + val = tmp; + else if (reduce_type == SUM || reduce_type == MEAN) + val += tmp; + } + } + } + + next_idx = SHFL_DOWN_SYNC(FULL_MASK, idx, 1); + if (lane_idx == 32 - 1 || row_idx / D != (row_idx + 1) / D || + idx != next_idx) { + // atomic_write + switch(reduce_type) { + case MIN: + atomMin(out_data + out_idx, val); + break; + case MAX: + atomMax(out_data + out_idx, val); + break; + case SUM: + atomAdd(out_data + out_idx, val); + break; + case MEAN: + atomAdd(out_data + out_idx, val); + break; + } + } + } +} + +template +__global__ void segment_coo_broadcast_cuda_forward_kernel(const data_t *x_data, + const TensorInfo index_info, + ReductionType reduce_type, + mp_t *out_data, + size_t E, + size_t K, + size_t N) { + + // Each thread processes a single column and `TB` index entries. Coalesced + // read and write is performed in column-major order. The intermediate + // results are written via atomics. + + int D = index_info.sizes[index_info.dims - 1]; + int E_1 = E / D; + int E_2 = (D - 1) + TB - ((D - 1) % TB); + + int row_idx = blockIdx.x * blockDim.y + threadIdx.y; + int col_idx = blockIdx.y * blockDim.x + threadIdx.x; + + int dim_start = (row_idx * TB) / E_2; + int row_start = (row_idx * TB) % E_2; + + if (dim_start < E_1 && col_idx < K) { + + int offset = IndexToOffset::get(dim_start * D + row_start, index_info); + int idx1 = __ldg(index_info.data + offset), idx2; + + mp_t val = static_cast(x_data[K * (dim_start * D + row_start) + col_idx]); + +#pragma unroll + for (int i = 1; i < TB; i++) { + if (row_start + i >= D) + break; + + idx2 = __ldg(index_info.data + offset + + i * index_info.strides[index_info.dims - 1]); + assert(idx1 <= idx2); + if (idx1 == idx2) { + mp_t tmp = static_cast(x_data[K * (dim_start * D + row_start + i) + col_idx]); + // update + if ((reduce_type == MIN && tmp < val) || + (reduce_type == MAX && tmp > val)) + val = tmp; + else if (reduce_type == SUM || reduce_type == MEAN) + val += tmp; + } else { + // atomic_write + switch(reduce_type) { + case MIN: + atomMin(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + case MAX: + atomMax(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + case SUM: + atomAdd(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + case MEAN: + atomAdd(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + } + val = x_data[K * (dim_start * D + row_start + i) + col_idx]; + } + + idx1 = idx2; + } + // atomic_write + switch(reduce_type) { + case MIN: + atomMin(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + case MAX: + atomMax(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + case SUM: + atomAdd(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + case MEAN: + atomAdd(out_data + (dim_start * N + idx1) * K + col_idx, val); + break; + } + } +} + + +template +__global__ void segment_coo_arg_kernel(const data_t *x_data, + const TensorInfo index_info, + mp_t *out_data, + index_t *arg_out_data, + size_t E, + size_t N) { + + int row_idx = blockIdx.x * blockDim.x + threadIdx.x; + int D = index_info.sizes[index_info.dims - 1]; + + if (row_idx < E) { + int offset = IndexToOffset::get(row_idx, index_info); + index_t idx = index_info.data[offset]; + int out_idx = (row_idx / 
D) * N + idx; + + mp_t val = __ldg(out_data + out_idx); + if (static_cast(x_data[row_idx]) == val) + arg_out_data[out_idx] = row_idx % D; + } +} + +template +__global__ void segment_coo_arg_broadcast_kernel(const data_t *x_data, + const TensorInfo index_info, + mp_t *out_data, + index_t *arg_out_data, + size_t E, + size_t K, + size_t N) { + + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / K; + int col_idx = thread_idx % K; + int D = index_info.sizes[index_info.dims - 1]; + + if (row_idx < E && col_idx < K) { + int offset = IndexToOffset::get(row_idx, index_info); + int idx = __ldg(index_info.data + offset); + int out_idx = ((row_idx / D) * N + idx) * K + col_idx; + + mp_t val = __ldg(out_data + out_idx); + if (static_cast(x_data[thread_idx]) == val) + arg_out_data[out_idx] = row_idx % D; + } +} + +template +__global__ void post_process_kernel(data_t init_val, + int numel, + data_t* out_data) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (tid < numel) { + if (out_data[tid] == init_val) + out_data[tid] = static_cast(0.0); + } +} + +template +__global__ void post_process_mean_kernel(mp_t* count_data, + int numel) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (tid < numel) { + if (count_data[tid] < static_cast(1.0)) + count_data[tid] = static_cast(1.0); + } +} + +std::vector segment_coo_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + std::vector return_shape, + std::string reduce) { + CHECK_CUDA(index); + if (init) + CHECK_CUDA(init.get()); + + auto x_dims = x.shape(); + auto index_dims = index.shape(); + CHECK_INPUT(x_dims.size() >= index_dims.size()); + for (auto i = 0; i < index_dims.size() - 1; ++i) + CHECK_INPUT(x_dims[i] >= index_dims[i]); + + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + auto dim = index_dims.size() - 1; + paddle::Tensor arg_out; + int count_numel; + if (reduce == "min" || reduce == "max") { + arg_out = paddle::experimental::full_like(out, x_dims[dim], index.dtype(), index.place()); + } else if (reduce == "mean") { + auto sizes = index.shape(); + sizes[dim] = return_shape[dim]; + arg_out = paddle::zeros(sizes, out.dtype(), index.place()); + count_numel = std::accumulate(sizes.begin(), sizes.end(), 1.0, std::multiplies()); + } + + auto E = index.numel(); + auto E_2 = index.shape()[dim]; + auto E_1 = index.numel() / E_2; + auto K = x.numel() / E; + auto N = out.shape()[dim]; + auto avg_len = (float)E_2 / (float)N; + + PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_2_TYPES( + paddle::DataType::FLOAT16, paddle::DataType::BFLOAT16, + x.dtype(), "segment_coo_cuda_forward_kernel", ([&] { + + using MPType = typename MPTypeTrait::Type; + paddle::Tensor out_mp; + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) { + out_mp = paddle::experimental::cast(out, paddle::DataType::FLOAT32); + } else { + out_mp = out; + } + + if (!init) { + if (reduce == "min") + paddle::experimental::fill_(out_mp, std::numeric_limits::max()); + else if (reduce == "max") + paddle::experimental::fill_(out_mp, std::numeric_limits::lowest()); + else if (reduce == "sum" || reduce == "mean") + paddle::experimental::fill_(out_mp, static_cast(0)); + } + + const data_t* x_data = 
x.data(); + MPType* out_data = out_mp.data(); + auto out_numel = std::accumulate(return_shape.begin(), return_shape.end(), 1.0, std::multiplies()); + + switch(index.dtype()) { + case paddle::DataType::INT32: + { + auto index_info = getTensorInfo(index); + if (K == 1) + segment_coo_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, N); + + else if (avg_len <= 8) + segment_coo_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + else if (avg_len <= 16) + segment_coo_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + else if (avg_len <= 32) + segment_coo_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + else + segment_coo_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + + if (reduce == "min" || reduce == "max") { + int* arg_out_data = arg_out.data(); + if (K == 1) + segment_coo_arg_kernel + <<>>( + x_data, index_info, out_data, arg_out_data, E, N); + else + segment_coo_arg_broadcast_kernel + <<>>( + x_data, index_info, out_data, arg_out_data, E, K, N); + } + + if (reduce == "mean") { + paddle::Tensor arg_out_mp; + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) { + arg_out_mp = paddle::experimental::cast(arg_out, paddle::DataType::FLOAT32); + } else { + arg_out_mp = arg_out; + } + auto count_data = arg_out_mp.data(); + segment_coo_cuda_forward_kernel + <<>>( + nullptr, index_info, reduce2REDUCE.at(reduce), + count_data, E, N); + post_process_mean_kernel + <<>>( + count_data, count_numel); + paddle::Tensor count = arg_out_mp; + for (int i = dim + 1; i < return_shape.size(); i++) { + count = paddle::experimental::unsqueeze(arg_out_mp, {-1}); + } + if (is_floating_point(out.dtype())) + paddle::experimental::divide_(out_mp, count); + else + paddle::experimental::floor_divide_(out_mp, count); + + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) { + arg_out = paddle::experimental::cast(arg_out_mp, x.dtype()); + } + } + + break; + } + case paddle::DataType::INT64: + { + auto index_info = getTensorInfo(index); + if (K == 1) + segment_coo_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, N); + + else if (avg_len <= 8) + segment_coo_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + else if (avg_len <= 16) + segment_coo_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + else if (avg_len <= 32) + segment_coo_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + else + segment_coo_broadcast_cuda_forward_kernel + <<>>( + x_data, index_info, reduce2REDUCE.at(reduce), out_data, E, K, N); + + if (reduce == "min" || reduce == "max") { + int64_t* arg_out_data = arg_out.data(); + if (K == 1) + segment_coo_arg_kernel + <<>>( + x_data, index_info, out_data, arg_out_data, E, N); + else + segment_coo_arg_broadcast_kernel + <<>>( + x_data, index_info, out_data, arg_out_data, E, K, N); + } + + if (reduce == "mean") { + paddle::Tensor arg_out_mp; + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) { + arg_out_mp = paddle::experimental::cast(arg_out, paddle::DataType::FLOAT32); + } else { + arg_out_mp = arg_out; + } + auto count_data = 
arg_out_mp.data(); + segment_coo_cuda_forward_kernel + <<>>( + nullptr, index_info, reduce2REDUCE.at(reduce), + count_data, E, N); + post_process_mean_kernel + <<>>( + count_data, count_numel); + paddle::Tensor count = arg_out_mp; + for (int i = dim + 1; i < return_shape.size(); i++) { + count = paddle::experimental::unsqueeze(arg_out_mp, {-1}); + } + if (is_floating_point(out.dtype())) + paddle::experimental::divide_(out_mp, count); + else + paddle::experimental::floor_divide_(out_mp, count); + + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) { + arg_out = paddle::experimental::cast(arg_out_mp, x.dtype()); + } + } + + break; + } + default: + PD_THROW( + "function segment_coo_cuda_forward_kernel is not implemented for the index data type `", + phi::DataTypeToString(index.dtype()), "`"); + } + + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) { + out = paddle::experimental::cast(out_mp, x.dtype()); + } + + if (!init) { + data_t init_val = static_cast((reduce == "min") ? std::numeric_limits::max() : std::numeric_limits::lowest()); + post_process_kernel + <<>>( + init_val, out_numel, out.data() + ); + } + + })); + + return {out, arg_out}; +} + + +template +__global__ void +gather_coo_kernel(const mp_t *src_data, + const TensorInfo index_info, + data_t *out_data, size_t E, size_t N) { + + int row_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (row_idx < E) { + int offset = IndexToOffset::get( + row_idx, index_info); + int row = index_info.data[offset]; + + offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N; + mp_t val = __ldg(src_data + offset + row); + + out_data[row_idx] = static_cast(val); + } +} + +template +__global__ void gather_coo_broadcast_kernel( + const mp_t *src_data, + const TensorInfo index_info, + data_t *out_data, size_t E, size_t K, size_t N) { + + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / K; + int col_idx = thread_idx % K; + + if (thread_idx < E * K) { + int offset = IndexToOffset::get( + row_idx, index_info); + int row = index_info.data[offset]; + + offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N * K; + mp_t val = __ldg(src_data + offset + K * row + col_idx); + + out_data[thread_idx] = static_cast(val); + } +} + + +std::vector gather_coo_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& index, + const paddle::optional& init, + std::vector return_shape) { + CHECK_CUDA(index); + if (init) + CHECK_CUDA(init.get()); + + auto x_dims = x.shape(); + auto index_dims = index.shape(); + CHECK_INPUT(x_dims.size() >= index_dims.size()); + for (auto i = 0; i < index_dims.size() - 1; ++i) + CHECK_INPUT(x_dims[i] == index_dims[i]); + + auto dim = index_dims.size() - 1; + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + auto E = index.numel(); + auto K = out.numel() / E; + auto N = x_dims[dim]; + + PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_2_TYPES( + paddle::DataType::FLOAT16, paddle::DataType::BFLOAT16, + x.dtype(), "gather_coo_cuda_forward_kernel", ([&] { + using MPType = typename MPTypeTrait::Type; + paddle::Tensor x_mp; + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) + x_mp = paddle::experimental::cast(x, 
paddle::DataType::FLOAT32); + else + x_mp = x; + + switch(index.dtype()) { + case paddle::DataType::INT32: + { + auto index_info = getTensorInfo(index); + auto stride = index_info.strides[index_info.dims - 1]; + if (K == 1) + gather_coo_kernel + <<>>( + x_mp.data(), index_info, out.data(), E, N); + else + gather_coo_broadcast_kernel + <<>>( + x_mp.data(), index_info, out.data(), E, K, N); + break; + } + case paddle::DataType::INT64: + { + auto index_info = getTensorInfo(index); + auto stride = index_info.strides[index_info.dims - 1]; + if (K == 1) + gather_coo_kernel + <<>>( + x_mp.data(), index_info, out.data(), E, N); + else + gather_coo_broadcast_kernel + <<>>( + x_mp.data(), index_info, out.data(), E, K, N); + break; + } + default: + PD_THROW( + "function gather_coo_cuda_forward_kernel is not implemented for the index data type `", + phi::DataTypeToString(index.dtype()), "`"); + } + })); + + return {out}; +} \ No newline at end of file diff --git a/jointContribution/paddle_scatter/csrc/segment_csr.cc b/jointContribution/paddle_scatter/csrc/segment_csr.cc new file mode 100644 index 000000000..6a602185e --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/segment_csr.cc @@ -0,0 +1,409 @@ +#include "paddle/extension.h" + +#include +#include +#include + +#include "index_info.h" +#include "utils.h" + + +template +void segment_csr_cpu_forward_kernel(const data_t* x_data, + const std::vector& indptr_shape, + const TensorInfo& indptr_info, + const std::string& reduce, + int stride, + int64_t dim, + int N, + int K, + int E, + data_t* out_data, + index_t* arg_out_data) { + using MPType = typename MPTypeTrait::Type; + std::vector args(K); + std::vector vals(K); + index_t row_start, row_end; + for (auto n = 0; n < N; n++) { + auto offset = IndexPtrToOffset::get(n, indptr_info); + row_start = indptr_info.data[offset]; + row_end = indptr_info.data[offset + stride]; + + offset = (n / (indptr_shape[dim] - 1)) * E * K; + for (auto k = 0; k < K; k++) { + // init + if (reduce == "min") + vals[k] = static_cast(std::numeric_limits::max()); + else if (reduce == "max") + vals[k] = static_cast(std::numeric_limits::lowest()); + else if (reduce == "sum" || reduce == "mean") + vals[k] = static_cast(0); + } + + for (auto e = row_start; e < row_end; e++) { + for (auto k = 0; k < K; k++) { + // update + auto cmp = static_cast(x_data[offset + e * K + k]); + if ((reduce == "min" && cmp < vals[k]) || + (reduce == "max" && cmp > vals[k])) { + vals[k] = cmp; + args[k] = e; + } else if (reduce == "sum" || reduce == "mean") { + vals[k] += cmp; + } + } + } + + for (auto k = 0; k < K; k++) { + // write + auto idx = n * K + k; + auto count = row_end - row_start; + if (reduce == "sum") { + out_data[idx] = static_cast(vals[k]); + } else if (reduce == "mean") { + out_data[idx] = static_cast(vals[k] / static_cast(count > 0 ? 
count : 1)); + } else if (reduce == "min" || reduce == "max") { + if (count > 0) { + out_data[idx] = static_cast(vals[k]); + arg_out_data[idx] = args[k]; + } else { + out_data[idx] = static_cast(0); + } + } + } + } +} + +std::vector segment_csr_cpu_forward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::optional& init, + const std::vector& return_shape, + const std::string& reduce) { + CHECK_CPU(indptr); + if (init) + CHECK_CPU(init.get()); + + auto x_dims = x.shape(); + auto indptr_dims = indptr.shape(); + CHECK_INPUT(x_dims.size() >= indptr_dims.size()); + auto dim = indptr_dims.size() - 1; + + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + paddle::Tensor arg_out; + if (reduce == "min" || reduce == "max") { + arg_out = paddle::experimental::full_like(out, x_dims[dim], indptr.dtype(), indptr.place()); + } + + auto N = return_shape[dim] * (indptr.numel() / indptr_dims[dim]); + auto K = out.numel() / N; + auto E = x_dims[dim]; + + PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES( + x.dtype(), "segment_csr_cpu_forward_kernel", ([&] { + + switch(indptr.dtype()) { + case paddle::DataType::INT32: + { + auto indptr_info = getTensorInfo(indptr); + int stride = indptr_info.strides[indptr_info.dims - 1]; + segment_csr_cpu_forward_kernel( + x.data(), indptr_dims, indptr_info, reduce, + stride, dim, N, K, E, out.data(), + (reduce == "min" || reduce == "max") ? arg_out.data() : nullptr); + break; + } + case paddle::DataType::INT64: + { + auto indptr_info = getTensorInfo(indptr); + int stride = indptr_info.strides[indptr_info.dims - 1]; + segment_csr_cpu_forward_kernel( + x.data(), indptr_dims, indptr_info, reduce, + stride, dim, N, K, E, out.data(), + (reduce == "min" || reduce == "max") ? 
arg_out.data() : nullptr); + break; + } + default: + PD_THROW( + "function segment_csr_cpu_forward_kernel is not implemented for the indptr data type `", + phi::DataTypeToString(indptr.dtype()), "`"); + } + })); + + return {out, arg_out}; +} + +template +void gather_csr_cpu_forward_kernel(const data_t* x_data, + const TensorInfo& indptr_info, + const std::vector& indptr_shape, + int64_t dim, + int stride, + int N, + int K, + int E, + data_t* out_data) { + std::vector vals(K); + int64_t row_start, row_end; + for (auto n = 0; n < N; n++) { + auto offset = IndexPtrToOffset::get(n, indptr_info); + row_start = indptr_info.data[offset]; + row_end = indptr_info.data[offset + stride]; + + for (auto k = 0; k < K; k++) + vals[k] = x_data[n * K + k]; + + offset = (n / (indptr_shape[dim] - 1)) * E * K; + for (auto e = row_start; e < row_end; e++) + for (auto k = 0; k < K; k++) + out_data[offset + e * K + k] = vals[k]; + } +} + +std::vector gather_csr_cpu_forward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::optional& init, + const std::vector& return_shape) { + CHECK_CPU(indptr); + if (init) + CHECK_CPU(init.get()); + + auto x_dims = x.shape(); + auto indptr_dims = indptr.shape(); + CHECK_INPUT(x_dims.size() >= indptr_dims.size()); + auto dim = indptr_dims.size() - 1; + + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + auto N = x_dims[dim] * (indptr.numel() / indptr_dims[dim]); + auto K = x.numel() / N; + auto E = return_shape[dim]; + + PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES( + x.dtype(), "gather_csr_cpu_forward_kernel", ([&] { + + switch(indptr.dtype()) { + case paddle::DataType::INT32: + { + auto indptr_info = getTensorInfo(indptr); + int stride = indptr_info.strides[indptr_info.dims - 1]; + gather_csr_cpu_forward_kernel( + x.data(), indptr_info, indptr_dims, dim, + stride, N, K, E, out.data()); + break; + } + case paddle::DataType::INT64: + { + auto indptr_info = getTensorInfo(indptr); + int stride = indptr_info.strides[indptr_info.dims - 1]; + gather_csr_cpu_forward_kernel( + x.data(), indptr_info, indptr_dims, dim, + stride, N, K, E, out.data()); + break; + } + default: + PD_THROW( + "function gather_csr_cpu_forward_kernel is not implemented for the indptr data type `", + phi::DataTypeToString(indptr.dtype()), "`"); + } + })); + + return {out}; +} + +std::vector segment_csr_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::optional& init, + const std::vector& return_shape, + const std::string& reduce); + +std::vector SegmentCsrForward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::optional& init, + const std::vector& return_shape, + const std::string& reduce) { + if (x.is_cpu()) { + return segment_csr_cpu_forward(x, indptr, init, return_shape, reduce); + } else if (x.is_gpu()) { + return segment_csr_cuda_forward(x, indptr, init, return_shape, reduce); + } else { + PD_THROW("Unsupported device type for forward function of custom segment_csr operator."); + } +} + +std::vector gather_csr_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::optional& init, + const std::vector& return_shape); + +std::vector GatherCsrForward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const 
paddle::optional& init, + const std::vector& return_shape) { + if (x.is_cpu()) { + return gather_csr_cpu_forward(x, indptr, init, return_shape); + } else if (x.is_gpu()) { + return gather_csr_cuda_forward(x, indptr, init, return_shape); + } else { + PD_THROW("Unsupported device type for forward function of custom gather_csr operator."); + } +} + +std::vector SegmentCsrBackward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::optional& arg_out, + const paddle::Tensor& grad_out, + std::string reduce) { + if (!x.is_cpu() && !x.is_gpu() ) { + PD_THROW("Unsupported device type for backward function of custom segment_csr operator."); + } + if (reduce == "min" || reduce == "max") { + int64_t dim = indptr.shape().size() - 1; + auto x_shape = x.shape(); + x_shape[dim] += 1; + auto grad_x = paddle::zeros(x_shape, x.dtype(), x.place()); + paddle::experimental::put_along_axis_(grad_x, arg_out.get(), grad_out, dim); + grad_x = paddle::experimental::slice(grad_x, {dim}, {0}, {x_shape[dim] - 1}, {1}, {}); + return {grad_x}; + } else if (reduce == "mean") { + auto grad_x = paddle::empty(x.shape(), x.dtype(), x.place()); + int64_t dim = indptr.shape().size() - 1; + if (grad_x.numel() > 0) { + grad_x = GatherCsrForward(grad_out, indptr, paddle::optional(grad_x), grad_x.shape())[0]; + auto indptr1 = paddle::experimental::slice(indptr, {dim}, {0}, {indptr.shape()[dim] - 1}, {1}, {}); + auto indptr2 = paddle::experimental::slice(indptr, {dim}, {1}, {indptr.shape()[dim]}, {1}, {}); + auto count = paddle::experimental::cast(indptr2 - indptr1, grad_x.dtype()); + auto sizes = count.shape(); + sizes[dim] = grad_x.shape()[dim]; + // sizes[dim] = *indptr.flatten()[-1].data(); + count = GatherCsrForward(count, indptr, paddle::optional(paddle::none), sizes)[0]; + for (auto i = 0; i < grad_out.shape().size() - indptr.shape().size(); i++) + paddle::experimental::unsqueeze_(count, {-1}); + paddle::experimental::divide_(grad_x, count); + } + return {grad_x}; + } else if (reduce == "sum") { + auto grad_x = paddle::empty(x.shape(), x.dtype(), x.place()); + paddle::Tensor grad_in = GatherCsrForward(grad_out, indptr, paddle::optional(grad_x), grad_x.shape())[0]; + return {grad_in}; + } +} + +std::vector GatherCsrBackward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::Tensor& grad_out) { + if (!x.is_cpu() && !x.is_gpu() ) { + PD_THROW("Unsupported device type for backward function of custom gather_csr operator."); + } + auto x_shape = x.shape(); + auto grad_x = paddle::empty(x_shape, x.dtype(), x.place()); + paddle::Tensor grad_in = SegmentCsrForward(grad_out, indptr, paddle::optional(grad_x), grad_x.shape(), "sum")[0]; + return {grad_in}; +} + +std::vector> SegmentCsrFWInferShape(const std::vector& x_shape, + const std::vector& indptr_shape, + const paddle::optional>& init_shape, + std::vector return_shape, + std::string reduce) { + return {return_shape, return_shape}; +} + +std::vector SegmentCsrFWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& indptr_dtype, + const paddle::optional& init_dtype) { + return {x_dtype, indptr_dtype}; +} + +std::vector> SegmentCsrBWInferShape(const std::vector& x_shape, + const std::vector& indptr_shape, + const paddle::optional>& arg_out_shape, + const std::vector& grad_out_shape) { + return {x_shape}; +} + +std::vector SegmentCsrBWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& indptr_dtype, + const paddle::optional& arg_out_dtype, + const paddle::DataType& grad_out_dtype) { + return {x_dtype}; +} + 
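+// Explanatory note (added for clarity; the concrete numbers below are an
+// illustrative sketch, not taken from the original sources). The backward
+// rules implemented by SegmentCsrBackward / GatherCsrBackward above are:
+//   * reduce == "min"/"max": the upstream gradient is scattered back to the
+//     arg-min / arg-max positions recorded in `arg_out` via put_along_axis_;
+//     every other input position receives zero, and the extra trailing slot
+//     absorbs the sentinel index written for empty segments before being
+//     sliced off.
+//   * reduce == "mean": the upstream gradient is gathered per input position
+//     with GatherCsrForward and divided by the segment length
+//     indptr[i + 1] - indptr[i].
+//   * reduce == "sum": the upstream gradient is simply gathered with
+//     GatherCsrForward; conversely, the backward of gather_csr is a "sum"
+//     segment_csr over the same indptr.
+// Example for the "mean" case (hypothetical values):
+//   indptr   = [0, 2, 5]   ->  segment lengths [2, 3]
+//   grad_out = [g0, g1]
+//   grad_x   = [g0/2, g0/2, g1/3, g1/3, g1/3]
+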
+std::vector> GatherCsrFWInferShape(const std::vector& x_shape, + const std::vector& indptr_shape, + const paddle::optional>& init_shape, + std::vector return_shape) { + return {return_shape}; +} + +std::vector GatherCsrFWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& indptr_dtype, + const paddle::optional& init_dtype) { + return {x_dtype}; +} + +std::vector> GatherCsrBWInferShape(const std::vector& x_shape, + const std::vector& indptr_shape, + const std::vector& grad_out_shape) { + return {x_shape}; +} + +std::vector GatherCsrBWInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& indptr_dtype, + const paddle::DataType& grad_out_dtype) { + return {x_dtype}; +} + +PD_BUILD_OP(custom_segment_csr) + .Inputs({"X", "Indptr", paddle::Optional("Init")}) + .Outputs({"Out", paddle::Optional("ArgOut")}) + .Attrs({"return_shape: std::vector", + "reduce: std::string"}) + .SetKernelFn(PD_KERNEL(SegmentCsrForward)) + .SetInferShapeFn(PD_INFER_SHAPE(SegmentCsrFWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SegmentCsrFWInferDtype)); + +PD_BUILD_GRAD_OP(custom_segment_csr) + .Inputs({"X", "Indptr", paddle::Optional("ArgOut"), paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .Attrs({"reduce: std::string"}) + .SetKernelFn(PD_KERNEL(SegmentCsrBackward)) + .SetInferShapeFn(PD_INFER_SHAPE(SegmentCsrBWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SegmentCsrBWInferDtype)); + +PD_BUILD_OP(custom_gather_csr) + .Inputs({"X", "Indptr", paddle::Optional("Init")}) + .Outputs({"Out"}) + .Attrs({"return_shape: std::vector"}) + .SetKernelFn(PD_KERNEL(GatherCsrForward)) + .SetInferShapeFn(PD_INFER_SHAPE(GatherCsrFWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(GatherCsrFWInferDtype)); + +PD_BUILD_GRAD_OP(custom_gather_csr) + .Inputs({"X", "Indptr", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(GatherCsrBackward)) + .SetInferShapeFn(PD_INFER_SHAPE(GatherCsrBWInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(GatherCsrBWInferDtype)); \ No newline at end of file diff --git a/jointContribution/paddle_scatter/csrc/segment_csr.cu b/jointContribution/paddle_scatter/csrc/segment_csr.cu new file mode 100644 index 000000000..7c5e8a127 --- /dev/null +++ b/jointContribution/paddle_scatter/csrc/segment_csr.cu @@ -0,0 +1,382 @@ +#include "paddle/extension.h" + +#include +#include +#include + +#include "atomics.cuh" +#include "index_info.cuh" +#include "utils.cuh" + +#define THREADS 256 +#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS +#define FULL_MASK 0xffffffff + + +enum ReductionType { MIN, MAX, SUM, MEAN }; + +const std::map reduce2REDUCE = { + {"min", MIN}, {"max", MAX}, {"sum", SUM}, {"mean", MEAN} +}; + +template +__global__ void segment_csr_kernel(const data_t *x_data, + const TensorInfo indptr_info, + ReductionType reduce_type, + data_t *out_data, + index_t *arg_out_data, + size_t N, + size_t E) { + + // Each warp processes exactly `32/TB` rows and aggregates all row values + // via a parallel reduction. 
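+  // Added explanation of the reduction pattern below (informal sketch):
+  //   1. The TB lanes that share one CSR row stride over [row_start, row_end)
+  //      with step TB, each accumulating a private partial result `val`
+  //      (plus its position `arg` for min/max).
+  //   2. A shuffle-based tree reduction (SHFL_DOWN_SYNC with offsets
+  //      TB/2, TB/4, ..., 1) folds the TB partials together, so after
+  //      log2(TB) steps lane 0 holds the reduced value for the whole row.
+  //   3. Only lane 0 writes the result (and, for min/max, the arg index).
+  //   E.g. (hypothetical values) with TB = 4, reduce = "min" and lane
+  //   partials [5, 2, 7, 1]: after the offset-2 step lane 0 holds
+  //   min(5, 7) = 5 and lane 1 holds min(2, 1) = 1; after the offset-1 step
+  //   lane 0 holds min(5, 1) = 1, which it writes out.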
+ + using MPType = typename MPTypeTrait::Type; + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / TB; + int lane_idx = thread_idx & (TB - 1); + + if (row_idx < N) { + int offset = IndexPtrToOffset::get(row_idx, indptr_info); + index_t row_start = __ldg(indptr_info.data + offset); + index_t row_end = __ldg(indptr_info.data + offset + + indptr_info.strides[indptr_info.dims - 1]); + + // init + data_t val; + if (reduce_type == MIN) + val = static_cast(std::numeric_limits::max()); + else if (reduce_type == MAX) + val = static_cast(std::numeric_limits::lowest()); + else if (reduce_type == SUM || reduce_type == MEAN) + val = static_cast(0); + + index_t arg, arg_tmp; + offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E; + for (index_t x_idx = row_start + lane_idx; x_idx < row_end; + x_idx += TB) { + // update + auto cmp = x_data[offset + x_idx]; + if ((reduce_type == MIN && cmp < val) || + (reduce_type == MAX && cmp > val)) { + val = cmp; + arg = x_idx; + } else if (reduce_type == SUM || reduce_type == MEAN) { + val += cmp; + } + } + +#pragma unroll + for (int i = TB / 2; i > 0; i /= 2) { + // Parallel reduction inside a single warp. + if (reduce_type == MIN || reduce_type == MAX) + arg_tmp = SHFL_DOWN_SYNC(FULL_MASK, arg, i); + // update + MPType cmp = SHFL_DOWN_SYNC(FULL_MASK, static_cast(val), i); + if ((reduce_type == MIN && cmp < static_cast(val)) || + (reduce_type == MAX && cmp > static_cast(val))) { + val = static_cast(cmp); + arg = arg_tmp; + } else if (reduce_type == SUM || reduce_type == MEAN) { + val += static_cast(cmp); + } + } + + if (lane_idx == 0) { + // write + auto count = row_end - row_start; + if (reduce_type == SUM) { + out_data[row_idx] = val; + } else if (reduce_type == MEAN) { + out_data[row_idx] = val / static_cast(count > 0 ? count : 1); + } else if (reduce_type == MIN || reduce_type == MAX) { + if (count > 0) { + out_data[row_idx] = val; + arg_out_data[row_idx] = arg; + } else { + out_data[row_idx] = static_cast(0); + } + } + } + } +} + +template +__global__ void segment_csr_broadcast_min_max_kernel(const data_t *x_data, + const TensorInfo indptr_info, + ReductionType reduce_type, + data_t *out_data, + index_t *arg_out_data, + size_t N, + size_t K, + size_t E) { + + // Each thread processes exactly one row. It turned out that is more + // efficient than using shared memory due to avoiding synchronization + // barriers. 
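+  // Added explanation (informal sketch): thread_idx is decomposed as
+  // row_idx = thread_idx / K (which CSR segment) and lane_idx = thread_idx % K
+  // (which broadcast column). Each thread then walks its segment
+  // [row_start, row_end) sequentially, keeping the running reduction value in
+  // a register, so the K neighbouring threads of a segment read K consecutive
+  // columns of the same x row and the global-memory accesses stay coalesced.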
+ + using MPType = typename MPTypeTrait::Type; + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / K; + int lane_idx = thread_idx % K; + + if (thread_idx < N * K) { + int offset = IndexPtrToOffset::get(row_idx, indptr_info); + index_t row_start = __ldg(indptr_info.data + offset); + index_t row_end = __ldg(indptr_info.data + offset + + indptr_info.strides[indptr_info.dims - 1]); + + // init + data_t val; + if (reduce_type == MIN) + val = static_cast(std::numeric_limits::max()); + else if (reduce_type == MAX) + val = static_cast(std::numeric_limits::lowest()); + else if (reduce_type == SUM || reduce_type == MEAN) + val = static_cast(0); + index_t arg; + + offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K; + for (index_t x_idx = row_start; x_idx < row_end; x_idx++) { + // update + auto cmp = x_data[offset + K * x_idx + lane_idx]; + if ((reduce_type == MIN && cmp < val) || + (reduce_type == MAX && cmp > val)) { + val = cmp; + arg = x_idx; + } else if (reduce_type == SUM || reduce_type == MEAN) { + val += cmp; + } + } + + // write + auto count = row_end - row_start; + if (reduce_type == SUM) { + out_data[thread_idx] = val; + } else if (reduce_type == MEAN) { + out_data[thread_idx] = val / static_cast(count > 0 ? count : 1); + } else if (reduce_type == MIN || reduce_type == MAX) { + if (count > 0) { + out_data[thread_idx] = val; + arg_out_data[thread_idx] = arg; + } else { + out_data[thread_idx] = static_cast(0); + } + } + } +} + +std::vector segment_csr_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::optional& init, + const std::vector& return_shape, + const std::string& reduce) { + CHECK_CUDA(indptr); + if (init) + CHECK_CUDA(init.get()); + + auto x_dims = x.shape(); + auto indptr_dims = indptr.shape(); + CHECK_INPUT(x_dims.size() >= indptr_dims.size()); + auto dim = indptr_dims.size() - 1; + + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + paddle::Tensor arg_out; + if (reduce == "min" || reduce == "max") { + arg_out = paddle::experimental::full_like(out, x_dims[dim], indptr.dtype(), indptr.place()); + } + + auto N = return_shape[dim] * (indptr.numel() / indptr_dims[dim]); + auto K = out.numel() / N; + auto E = x_dims[dim]; + + PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_2_TYPES( + paddle::DataType::FLOAT16, paddle::DataType::BFLOAT16, + x.dtype(), "segment_csr_cuda_forward_kernel", ([&] { + + const data_t* x_data = x.data(); + data_t* out_data = out.data(); + switch(indptr.dtype()) { + case paddle::DataType::INT32: + { + auto indptr_info = getTensorInfo(indptr); + int* arg_out_data = (reduce == "min" || reduce == "max") ? arg_out.data() : nullptr; + if (K == 1) { + segment_csr_kernel + <<>>( + x_data, indptr_info, reduce2REDUCE.at(reduce), + out_data, arg_out_data, N, E); + } else { + segment_csr_broadcast_min_max_kernel + <<>>( + x_data, indptr_info, reduce2REDUCE.at(reduce), + out_data, arg_out_data, N, K, E); + } + break; + } + case paddle::DataType::INT64: + { + auto indptr_info = getTensorInfo(indptr); + int64_t* arg_out_data = (reduce == "min" || reduce == "max") ? 
arg_out.data() : nullptr; + if (K == 1) { + segment_csr_kernel + <<>>( + x_data, indptr_info, reduce2REDUCE.at(reduce), + out_data, arg_out_data, N, E); + } else { + segment_csr_broadcast_min_max_kernel + <<>>( + x_data, indptr_info, reduce2REDUCE.at(reduce), + out_data, arg_out_data, N, K, E); + } + break; + } + default: + PD_THROW( + "function segment_csr_cuda_forward_kernel is not implemented for the indptr data type `", + phi::DataTypeToString(indptr.dtype()), "`"); + } + })); + + return {out, arg_out}; +} + +template +__global__ void +gather_csr_kernel(const mp_t *src_data, + const TensorInfo indptr_info, + data_t *out_data, size_t N, size_t E) { + + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / TB; + int lane_idx = thread_idx % TB; + + if (row_idx < N) { + int offset = IndexPtrToOffset::get(row_idx, indptr_info); + int row_start = __ldg(indptr_info.data + offset); + int row_end = __ldg(indptr_info.data + offset + + indptr_info.strides[indptr_info.dims - 1]); + mp_t val = __ldg(src_data + row_idx); + + offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E; + for (int out_idx = row_start + lane_idx; out_idx < row_end; out_idx += TB) { + out_data[offset + out_idx] = static_cast(val); // "Mostly" coalesced. + } + } +} + +template +__global__ void gather_csr_broadcast_kernel( + const mp_t *src_data, + const TensorInfo indptr_info, + data_t *out_data, size_t N, size_t K, size_t E) { + + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = thread_idx / K; + int lane_idx = thread_idx % K; + + if (thread_idx < N * K) { + int offset = IndexPtrToOffset::get(row_idx, indptr_info); + int row_start = __ldg(indptr_info.data + offset); + int row_end = __ldg(indptr_info.data + offset + + indptr_info.strides[indptr_info.dims - 1]); + + mp_t val = src_data[thread_idx]; // Coalesced. + + offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K; + for (int out_idx = row_start; out_idx < row_end; out_idx++) { + out_data[offset + K * out_idx + lane_idx] = static_cast(val); // "Mostly" coalesced. 
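+      // Added note: within one source row the K neighbouring threads write
+      // K consecutive output elements per iteration, which coalesces well;
+      // threads of a warp that belong to different rows have different
+      // [row_start, row_end) ranges and therefore scatter their writes,
+      // hence only "mostly" coalesced.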
+ } + } +} + + +std::vector gather_csr_cuda_forward(const paddle::Tensor& x, + const paddle::Tensor& indptr, + const paddle::optional& init, + const std::vector& return_shape) { + CHECK_CUDA(indptr); + if (init) + CHECK_CUDA(init.get()); + + auto x_dims = x.shape(); + auto indptr_dims = indptr.shape(); + CHECK_INPUT(x_dims.size() >= indptr_dims.size()); + auto dim = indptr_dims.size() - 1; + + // custom op input tensors are already contiguous + // x = x.contiguous(); + + paddle::Tensor out; + if (init) { + // paddle::Tensor init_contiguous = init->contiguous(); + // out = paddle::Tensor(init_contiguous); + out = paddle::Tensor(init.get()); + } + else { + out = paddle::empty(return_shape, x.dtype(), x.place()); + } + + auto N = x_dims[dim] * (indptr.numel() / indptr_dims[dim]); + auto K = x.numel() / N; + auto E = return_shape[dim]; + + PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_2_TYPES( + paddle::DataType::FLOAT16, paddle::DataType::BFLOAT16, + x.dtype(), "gather_csr_cuda_forward_kernel", ([&] { + using MPType = typename MPTypeTrait::Type; + paddle::Tensor x_mp; + if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) + x_mp = paddle::experimental::cast(x, paddle::DataType::FLOAT32); + else + x_mp = x; + const MPType* x_data = x_mp.data(); + data_t* out_data = out.data(); + switch(indptr.dtype()) { + case paddle::DataType::INT32: + { + auto indptr_info = getTensorInfo(indptr); + if (K == 1) + gather_csr_kernel + <<>>( + x_data, indptr_info, out_data, N, E); + else + gather_csr_broadcast_kernel + <<>>( + x_data, indptr_info, out_data, N, K, E); + break; + } + case paddle::DataType::INT64: + { + auto indptr_info = getTensorInfo(indptr); + if (K == 1) + gather_csr_kernel + <<>>( + x_data, indptr_info, out_data, N, E); + else + gather_csr_broadcast_kernel + <<>>( + x_data, indptr_info, out_data, N, K, E); + break; + } + default: + PD_THROW( + "function gather_csr_cuda_forward_kernel is not implemented for the indptr data type `", + phi::DataTypeToString(indptr.dtype()), "`"); + } + })); + + return {out}; +} \ No newline at end of file diff --git a/jointContribution/paddle_scatter/csrc/utils.cuh b/jointContribution/paddle_scatter/csrc/utils.cuh index d5759541b..f5371e043 100644 --- a/jointContribution/paddle_scatter/csrc/utils.cuh +++ b/jointContribution/paddle_scatter/csrc/utils.cuh @@ -85,4 +85,4 @@ class MPTypeTrait { #else #define SHFL_UP_SYNC __shfl_up_sync #define SHFL_DOWN_SYNC __shfl_down_sync -#endif \ No newline at end of file +#endif diff --git a/jointContribution/paddle_scatter/scatter.py b/jointContribution/paddle_scatter/scatter.py index c513ecf52..7d9b5d231 100644 --- a/jointContribution/paddle_scatter/scatter.py +++ b/jointContribution/paddle_scatter/scatter.py @@ -2,16 +2,7 @@ from typing import Tuple import paddle -from paddle import assign -from paddle import divide -from paddle import floor_divide -from paddle import full -from paddle import full_like -from paddle import ones -from paddle import put_along_axis -from paddle import where -from paddle import zeros -from paddle_scatter_min_max_ops import custom_scatter_min_max +import paddle_scatter_ops from .utils import broadcast @@ -51,15 +42,19 @@ def scatter_sum( size[dim] = 0 else: size[dim] = int(index.max()) + 1 - arr = zeros(size, dtype=src.dtype) + arr = paddle.zeros(size, dtype=src.dtype) if src.numel() == 0: return arr - return put_along_axis(arr, indices=index, values=src, axis=dim, reduce="add") + return paddle.put_along_axis( + arr, indices=index, values=src, axis=dim, reduce="add" + ) 
else: if src.numel() == 0: return out - result = put_along_axis(out, indices=index, values=src, axis=dim, reduce="add") - assign(result, out) + result = paddle.put_along_axis( + out, indices=index, values=src, axis=dim, reduce="add" + ) + paddle.assign(result, out) return out @@ -127,15 +122,19 @@ def scatter_mul( size[dim] = 0 else: size[dim] = int(index.max()) + 1 - arr = ones(size, dtype=src.dtype) + arr = paddle.ones(size, dtype=src.dtype) if src.numel() == 0: return arr - return put_along_axis(arr, indices=index, values=src, axis=dim, reduce="mul") + return paddle.put_along_axis( + arr, indices=index, values=src, axis=dim, reduce="mul" + ) else: if src.numel() == 0: return out - result = put_along_axis(out, indices=index, values=src, axis=dim, reduce="mul") - assign(result, out) + result = paddle.put_along_axis( + out, indices=index, values=src, axis=dim, reduce="mul" + ) + paddle.assign(result, out) return out @@ -174,18 +173,18 @@ def scatter_mean( if index.dim() <= index_dim: index_dim = index.dim() - 1 - ones_tensor = ones(index.shape, dtype=src.dtype) + ones_tensor = paddle.ones(index.shape, dtype=src.dtype) tmp = scatter_sum(ones_tensor, index, index_dim, None, dim_size) - count = where(tmp < 1, full_like(tmp, 1), tmp, name="where") + count = paddle.where(tmp < 1, paddle.full_like(tmp, 1), tmp, name="where") count = broadcast(count, sums, dim) if sums.is_floating_point(): - result = divide(sums, count) + result = paddle.divide(sums, count) else: - result = floor_divide(sums, count) + result = paddle.floor_divide(sums, count) if out is None: return result else: - assign(result, out) + paddle.assign(result, out) return out @@ -230,20 +229,22 @@ def scatter_min( if out is None: if src.numel() == 0: return ( - zeros(size, dtype=src.dtype), - full(size, src.shape[dim], index.dtype), + paddle.zeros(size, dtype=src.dtype), + paddle.full(size, src.shape[dim], index.dtype), ) - return custom_scatter_min_max(src, index, None, size, "min", dim) + return paddle_scatter_ops.custom_scatter_min_max( + src, index, None, size, "min", dim + ) else: if src.numel() == 0: - return (out, full(size, src.shape[dim], index.dtype)) + return (out, paddle.full(size, src.shape[dim], index.dtype)) for i in range(len(size)): if i != dim: assert size[i] == out.shape[i] - result, arg_result = custom_scatter_min_max( + result, arg_result = paddle_scatter_ops.custom_scatter_min_max( src, index, out, out.shape, "min", dim ) - assign(result, out) + paddle.assign(result, out) return out, arg_result @@ -288,20 +289,22 @@ def scatter_max( if out is None: if src.numel() == 0: return ( - zeros(size, dtype=src.dtype), - full(size, src.shape[dim], index.dtype), + paddle.zeros(size, dtype=src.dtype), + paddle.full(size, src.shape[dim], index.dtype), ) - return custom_scatter_min_max(src, index, None, size, "max", dim) + return paddle_scatter_ops.custom_scatter_min_max( + src, index, None, size, "max", dim + ) else: if src.numel() == 0: - return (out, full(size, src.shape[dim], index.dtype)) + return (out, paddle.full(size, src.shape[dim], index.dtype)) for i in range(len(size)): if i != dim: assert size[i] == out.shape[i] - result, arg_result = custom_scatter_min_max( + result, arg_result = paddle_scatter_ops.custom_scatter_min_max( src, index, out, out.shape, "max", dim ) - assign(result, out) + paddle.assign(result, out) return out, arg_result diff --git a/jointContribution/paddle_scatter/segment_coo.py b/jointContribution/paddle_scatter/segment_coo.py index 75af63aff..91ea2e5aa 100644 --- 
a/jointContribution/paddle_scatter/segment_coo.py +++ b/jointContribution/paddle_scatter/segment_coo.py @@ -2,19 +2,7 @@ from typing import Tuple import paddle -from paddle import assign -from paddle import full -from paddle import slice -from paddle import take_along_axis -from paddle import to_tensor -from paddle import zeros -from paddle.geometric import segment_mean -from paddle.geometric import segment_sum -from paddle.nn.functional import pad -from paddle_scatter_min_max_ops import custom_segment_coo_min_max - -from .utils import broadcast -from .utils import transform_3d +import paddle_scatter_ops def segment_sum_coo( @@ -41,46 +29,42 @@ def segment_sum_coo( Returns: paddle.Tensor, the reduced tensor by sum reduction method. """ + src_shape = src.shape index_shape = index.shape dim = len(index_shape) - 1 - index = broadcast(index, src, dim) + # broadcast indptr to src + index_shape[:dim] = src_shape[:dim] + if src.numel() == 0: + index = index.reshape(index_shape) + else: + index = index.expand(index_shape) + size = src.shape + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + tmp = index.index_select( + index=paddle.to_tensor([index.shape[dim] - 1]), axis=dim + ).squeeze(dim) + tmp = tmp.max() if tmp.numel() > 1 else tmp + size[dim] = int(tmp) + 1 + if out is None: - out_size = src.shape - if dim_size is not None: - out_size[dim] = dim_size - elif index.numel() == 0: - out_size[dim] = 0 - else: - tmp = index.index_select( - index=to_tensor([index.shape[dim] - 1]), axis=dim - ).squeeze(dim) - tmp = tmp.max() if tmp.numel() > 1 else tmp - out_size[dim] = int(tmp) + 1 if src.numel() == 0: - return zeros(out_size, dtype=src.dtype) + return paddle.zeros(size, dtype=src.dtype) + return paddle_scatter_ops.custom_segment_coo(src, index, None, size, "sum")[0] else: if src.numel() == 0: return out - out_size = out.shape - init = out.clone() - tmp = zeros(out_size, dtype=src.dtype) - src_fatten = transform_3d(src, dim) - index_flatten = transform_3d(index, dim) - out_flatten = transform_3d(tmp, dim) - for i in range(src_fatten.shape[0]): - for j in range(src_fatten.shape[-1]): - result = segment_sum(src_fatten[i, :, j], index_flatten[i, :, j]) - if out_size[dim] >= len(result): - out_flatten[i, : len(result), j] = result - else: - out_flatten[i, :, j] = result[: out_size[dim]] - res = out_flatten.reshape(out_size) - if out is None: - return res - else: - res = res + init - assign(res, out) - return out + for i in range(len(size)): + if i != dim: + assert size[i] == out.shape[i] + result = paddle_scatter_ops.custom_segment_coo( + src, index, out, out.shape, "sum" + )[0] + paddle.assign(result, out) + return out def segment_add_coo( @@ -134,38 +118,42 @@ def segment_mean_coo( Returns: paddle.Tensor, the reduced tensor by mean reduction method. 
""" + src_shape = src.shape index_shape = index.shape dim = len(index_shape) - 1 - index = broadcast(index, src, dim) + # broadcast indptr to src + index_shape[:dim] = src_shape[:dim] + if src.numel() == 0: + index = index.reshape(index_shape) + else: + index = index.expand(index_shape) + size = src.shape + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + tmp = index.index_select( + index=paddle.to_tensor([index.shape[dim] - 1]), axis=dim + ).squeeze(dim) + tmp = tmp.max() if tmp.numel() > 1 else tmp + size[dim] = int(tmp) + 1 + if out is None: - out_size = src.shape - if dim_size is not None: - out_size[dim] = dim_size - elif index.numel() == 0: - out_size[dim] = 0 - else: - tmp = index.index_select( - index=to_tensor([index.shape[dim] - 1]), axis=dim - ).squeeze(dim) - tmp = tmp.max() if tmp.numel() > 1 else tmp - out_size[dim] = int(tmp) + 1 - out = zeros(out_size, dtype=src.dtype) + if src.numel() == 0: + return paddle.zeros(size, dtype=src.dtype) + return paddle_scatter_ops.custom_segment_coo(src, index, None, size, "mean")[0] else: - out_size = out.shape - if src.numel() == 0: + if src.numel() == 0: + return out + for i in range(len(size)): + if i != dim: + assert size[i] == out.shape[i] + result = paddle_scatter_ops.custom_segment_coo( + src, index, out, out.shape, "mean" + )[0] + paddle.assign(result, out) return out - src_fatten = transform_3d(src, dim) - index_flatten = transform_3d(index, dim) - out_flatten = transform_3d(out, dim) - for i in range(src_fatten.shape[0]): - for j in range(src_fatten.shape[-1]): - result = segment_mean(src_fatten[i, :, j], index_flatten[i, :, j]) - if out_size[dim] >= len(result): - out_flatten[i, : len(result), j] = result - else: - out_flatten[i, :, j] = result[: out_size[dim]] - assign(out_flatten.reshape(out_size), out) - return out def segment_min_coo( @@ -208,7 +196,7 @@ def segment_min_coo( size[dim] = 0 else: tmp = index.index_select( - index=to_tensor([index.shape[dim] - 1]), axis=dim + index=paddle.to_tensor([index.shape[dim] - 1]), axis=dim ).squeeze(dim) tmp = tmp.max() if tmp.numel() > 1 else tmp size[dim] = int(tmp) + 1 @@ -216,20 +204,20 @@ def segment_min_coo( if out is None: if src.numel() == 0: return ( - zeros(size, dtype=src.dtype), - full(size, src.shape[dim], index.dtype), + paddle.zeros(size, dtype=src.dtype), + paddle.full(size, src.shape[dim], index.dtype), ) - return custom_segment_coo_min_max(src, index, None, size, "min") + return paddle_scatter_ops.custom_segment_coo(src, index, None, size, "min") else: if src.numel() == 0: - return (out, full(size, src.shape[dim], index.dtype)) + return (out, paddle.full(size, src.shape[dim], index.dtype)) for i in range(len(size)): if i != dim: assert size[i] == out.shape[i] - result, arg_result = custom_segment_coo_min_max( + result, arg_result = paddle_scatter_ops.custom_segment_coo( src, index, out, out.shape, "min" ) - assign(result, out) + paddle.assign(result, out) return out, arg_result @@ -273,7 +261,7 @@ def segment_max_coo( size[dim] = 0 else: tmp = index.index_select( - index=to_tensor([index.shape[dim] - 1]), axis=dim + index=paddle.to_tensor([index.shape[dim] - 1]), axis=dim ).squeeze(dim) tmp = tmp.max() if tmp.numel() > 1 else tmp size[dim] = int(tmp) + 1 @@ -281,20 +269,20 @@ def segment_max_coo( if out is None: if src.numel() == 0: return ( - zeros(size, dtype=src.dtype), - full(size, src.shape[dim], index.dtype), + paddle.zeros(size, dtype=src.dtype), + paddle.full(size, src.shape[dim], index.dtype), ) - return 
custom_segment_coo_min_max(src, index, None, size, "max") + return paddle_scatter_ops.custom_segment_coo(src, index, None, size, "max") else: if src.numel() == 0: - return (out, full(size, src.shape[dim], index.dtype)) + return (out, paddle.full(size, src.shape[dim], index.dtype)) for i in range(len(size)): if i != dim: assert size[i] == out.shape[i] - result, arg_result = custom_segment_coo_min_max( + result, arg_result = paddle_scatter_ops.custom_segment_coo( src, index, out, out.shape, "max" ) - assign(result, out) + paddle.assign(result, out) return out, arg_result @@ -429,33 +417,20 @@ def gather_coo( dim = len(index_shape) - 1 src_shape = src.shape - # broadcast index's dimension to the same as src's - for _ in range(len(src_shape) - len(index_shape)): - index = index.unsqueeze(-1) - new_index_shape = src_shape.copy() - new_index_shape[dim] = index_shape[dim] - if src.numel() == 0: - index = index.reshape(new_index_shape) - else: - index = index.expand(new_index_shape) - if out is None: out_size = src_shape if index.numel() == 0: out_size[dim] = 0 - else: - out_size[dim] = index.shape[dim] - out = zeros(out_size, dtype=src.dtype) + return paddle.zeros(out_size, dtype=src.dtype) + out_size[dim] = index.shape[dim] + return paddle_scatter_ops.custom_gather_coo(src, index, None, out_size) else: + if src.numel() == 0: + return out out_size = out.shape - if src.numel() == 0: + for i in range(len(out_size)): + if i != dim: + assert src_shape[i] == out_size[i] + result = paddle_scatter_ops.custom_gather_coo(src, index, out, out_size) + paddle.assign(result, out) return out - - result = take_along_axis(src, index, dim, broadcast=False) - if out_size[dim] > result.shape[dim]: - padding = [0] * 2 * len(out_size) - padding[2 * dim + 1] = out_size[dim] - result.shape[dim] - result = pad(result, padding, mode="constant", value=0) - elif out_size[dim] < result.shape[dim]: - result = slice(result, [dim], 0, out_size[dim]) - return assign(result, out) diff --git a/jointContribution/paddle_scatter/segment_csr.py b/jointContribution/paddle_scatter/segment_csr.py index 6ecfbeb92..8d0e7eec8 100644 --- a/jointContribution/paddle_scatter/segment_csr.py +++ b/jointContribution/paddle_scatter/segment_csr.py @@ -2,17 +2,7 @@ from typing import Tuple import paddle -from paddle import arange -from paddle import assign -from paddle import full -from paddle import repeat_interleave -from paddle import zeros -from paddle.geometric import segment_mean -from paddle.geometric import segment_sum -from paddle_scatter_min_max_ops import custom_segment_csr_min_max - -from .utils import transform_2d -from .utils import transform_3d +import paddle_scatter_ops def segment_sum_csr( @@ -48,46 +38,19 @@ def segment_sum_csr( else: indptr = indptr.expand(indptr_shape) - num_seg = indptr_shape[dim] - 1 if out is None: - out_size = src_shape - if indptr.numel() == 0: - out_size[dim] = 0 - else: - out_size[dim] = num_seg + size = src.shape + size[dim] = max(indptr_shape[dim] - 1, 0) if src.numel() == 0: - return zeros(out_size, dtype=src.dtype) + return paddle.zeros(size, dtype=src.dtype) + return paddle_scatter_ops.custom_segment_csr(src, indptr, None, size, "sum")[0] else: - assert ( - out.shape[dim] == num_seg - ), "The (size of indptr at last dimension) must be\ - equal to the (size of out at the same dimension) + 1." 
if src.numel() == 0: return out - out_size = out.shape - tmp = zeros(out_size, dtype=src.dtype) - - repeats = indptr.diff(n=1, axis=dim) - assert ( - repeats.sum(axis=dim) == src.shape[dim] - ).all(), "The length of specified index by indptr shoud be\ - equal to the size of src at last dimension of indptr." - src_flatten = transform_3d(src, dim) - out_flatten = transform_3d(tmp, dim) - repeats_flatten = transform_2d(repeats, dim) - src_dim_indices = arange(num_seg) - for i in range(src_flatten.shape[0]): - for j in range(src_flatten.shape[-1]): - belongs_to = repeat_interleave(src_dim_indices, repeats_flatten[i], 0) - result = segment_sum(src_flatten[i, :, j], belongs_to) - if out_size[dim] >= len(result): - out_flatten[i, : len(result), j] = result - else: - out_flatten[i, :, j] = result[: out_size[dim]] - if out is None: - return out_flatten.reshape(out_size) - else: - assign(out_flatten.reshape(out_size), out) + result = paddle_scatter_ops.custom_segment_csr( + src, indptr, out, out.shape, "sum" + )[0] + paddle.assign(result, out) return out @@ -150,46 +113,19 @@ def segment_mean_csr( else: indptr = indptr.expand(indptr_shape) - num_seg = indptr_shape[dim] - 1 if out is None: - out_size = src_shape - if indptr.numel() == 0: - out_size[dim] = 0 - else: - out_size[dim] = num_seg + size = src.shape + size[dim] = max(indptr_shape[dim] - 1, 0) if src.numel() == 0: - return zeros(out_size, dtype=src.dtype) + return paddle.zeros(size, dtype=src.dtype) + return paddle_scatter_ops.custom_segment_csr(src, indptr, None, size, "mean")[0] else: - assert ( - out.shape[dim] == num_seg - ), "The (size of indptr at last dimension) must be\ - equal to the (size of out at the same dimension) + 1." if src.numel() == 0: return out - out_size = out.shape - tmp = zeros(out_size, dtype=src.dtype) - - repeats = indptr.diff(n=1, axis=dim) - assert ( - repeats.sum(axis=dim) == src.shape[dim] - ).all(), "The length of specified index by indptr shoud be\ - equal to the size of src at last dimension of indptr." 
- src_flatten = transform_3d(src, dim) - out_flatten = transform_3d(tmp, dim) - repeats_flatten = transform_2d(repeats, dim) - src_dim_indices = arange(num_seg) - for i in range(src_flatten.shape[0]): - for j in range(src_flatten.shape[-1]): - belongs_to = repeat_interleave(src_dim_indices, repeats_flatten[i], 0) - result = segment_mean(src_flatten[i, :, j], belongs_to) - if out_size[dim] >= len(result): - out_flatten[i, : len(result), j] = result - else: - out_flatten[i, :, j] = result[: out_size[dim]] - if out is None: - return out_flatten.reshape(out_size) - else: - assign(out_flatten.reshape(out_size), out) + result = paddle_scatter_ops.custom_segment_csr( + src, indptr, out, out.shape, "mean" + )[0] + paddle.assign(result, out) return out @@ -231,15 +167,17 @@ def segment_min_csr( size[dim] = max(indptr_shape[dim] - 1, 0) if src.numel() == 0: return ( - zeros(size, dtype=src.dtype), - full(size, src.shape[dim], indptr.dtype), + paddle.zeros(size, dtype=src.dtype), + paddle.full(size, src.shape[dim], indptr.dtype), ) - return custom_segment_csr_min_max(src, indptr, size, "min") + return paddle_scatter_ops.custom_segment_csr(src, indptr, None, size, "min") else: if src.numel() == 0: - return (out, full(size, src.shape[dim], indptr.dtype)) - result, arg_result = custom_segment_csr_min_max(src, indptr, out.shape, "min") - assign(result, out) + return (out, paddle.full(size, src.shape[dim], indptr.dtype)) + result, arg_result = paddle_scatter_ops.custom_segment_csr( + src, indptr, out, out.shape, "min" + ) + paddle.assign(result, out) return out, arg_result @@ -276,20 +214,25 @@ def segment_max_csr( else: indptr = indptr.expand(indptr_shape) + size = src.shape if out is None: - size = src.shape size[dim] = max(indptr_shape[dim] - 1, 0) if src.numel() == 0: return ( - zeros(size, dtype=src.dtype), - full(size, src.shape[dim], indptr.dtype), + paddle.zeros(size, dtype=src.dtype), + paddle.full(size, src.shape[dim], indptr.dtype), ) - return custom_segment_csr_min_max(src, indptr, size, "max") + return paddle_scatter_ops.custom_segment_csr(src, indptr, None, size, "max") else: if src.numel() == 0: - return (out, full(size, src.shape[dim], indptr.dtype)) - result, arg_result = custom_segment_csr_min_max(src, indptr, out.shape, "max") - assign(result, out) + return (out, paddle.full(size, src.shape[dim], indptr.dtype)) + for i in range(len(size)): + if i != dim: + assert size[i] == out.shape[i] + result, arg_result = paddle_scatter_ops.custom_segment_csr( + src, indptr, out, out.shape, "max" + ) + paddle.assign(result, out) return out, arg_result @@ -429,26 +372,16 @@ def gather_csr( out_size = src_shape if indptr.numel() == 0: out_size[dim] = 0 - else: - # refer to the original design in source cpp code - out_size[dim] = indptr.flatten()[-1] - out = zeros(out_size, dtype=src.dtype) + return paddle.zeros(out_size, dtype=src.dtype) + out_size[dim] = indptr.flatten()[-1] + return paddle_scatter_ops.custom_gather_csr(src, indptr, None, out_size) else: + if src.numel() == 0: + return out out_size = out.shape - if src.numel() == 0: + for i in range(len(out_size)): + if i != dim: + assert src_shape[i] == out_size[i] + result = paddle_scatter_ops.custom_gather_csr(src, indptr, out, out_size) + paddle.assign(result, out) return out - - repeats = indptr.diff(n=1, axis=dim) - src_flatten = transform_3d(src, dim) - out_flatten = transform_3d(out, dim) - repeats_flatten = transform_2d(repeats, dim) - for i in range(src_flatten.shape[0]): - for j in range(src_flatten.shape[-1]): - result = 
repeat_interleave(src_flatten[i, :, j], repeats_flatten[i], 0) - repeat_sum = repeats_flatten[i].sum() - if out_size[dim] >= repeat_sum: - out_flatten[i, :repeat_sum, j] = result[:repeat_sum] - else: - out_flatten[i, :, j] = result[: out_size[dim]] - assign(out_flatten.reshape(out_size), out) - return out diff --git a/jointContribution/paddle_scatter/setup.py b/jointContribution/paddle_scatter/setup.py index d78a28133..908c0b44e 100644 --- a/jointContribution/paddle_scatter/setup.py +++ b/jointContribution/paddle_scatter/setup.py @@ -32,10 +32,10 @@ def get_extensions(): setup( - name="paddle_scatter_min_max_ops", + name="paddle_scatter_ops", version="1.0", author="NKNaN", url="https://github.com/PaddlePaddle/PaddleScience/jointContribution/paddle_scatter", - description="Paddle extension of scatter and segment operators with min and max reduction methods", + description="Paddle extension of scatter and segment operators with min and max reduction methods, originally from https://github.com/rusty1s/pytorch_scatter", ext_modules=get_extensions(), ) diff --git a/jointContribution/paddle_scatter/testing.py b/jointContribution/paddle_scatter/testing.py index dfe23294c..1126fb8a6 100644 --- a/jointContribution/paddle_scatter/testing.py +++ b/jointContribution/paddle_scatter/testing.py @@ -4,13 +4,8 @@ reductions = ["sum", "add", "mean", "min", "max"] -dtypes = [ - # paddle.float16, paddle.bfloat16, - paddle.float32, - paddle.float64, - paddle.int32, - paddle.int64, -] +dtypes = [paddle.float32, paddle.float64, paddle.int32, paddle.int64] +dtypes_half = [paddle.float16, paddle.bfloat16] ind_dtypes = [paddle.int32, paddle.int64] grad_dtypes = [paddle.float32, paddle.float64] diff --git a/jointContribution/paddle_scatter/tests/composite/test_softmax.py b/jointContribution/paddle_scatter/tests/composite/test_softmax.py index 20c2c6bb8..28ae9410c 100644 --- a/jointContribution/paddle_scatter/tests/composite/test_softmax.py +++ b/jointContribution/paddle_scatter/tests/composite/test_softmax.py @@ -39,7 +39,7 @@ def test_log_softmax(): paddle.to_tensor([7], dtype=paddle.float32), axis=-1 ) out4 = paddle.nn.functional.log_softmax( - paddle.to_tensor([-1, float("-inf")]), axis=-1 + paddle.to_tensor([-1.0, float("-inf")]), axis=-1 ) expected = paddle.stack( diff --git a/jointContribution/paddle_scatter/tests/test_gather.py b/jointContribution/paddle_scatter/tests/test_gather.py index fbbb4df42..6fd8e40a3 100644 --- a/jointContribution/paddle_scatter/tests/test_gather.py +++ b/jointContribution/paddle_scatter/tests/test_gather.py @@ -5,6 +5,7 @@ from paddle_scatter import gather_coo from paddle_scatter import gather_csr from paddle_scatter.testing import dtypes +from paddle_scatter.testing import dtypes_half from paddle_scatter.testing import ind_dtypes from paddle_scatter.testing import places from paddle_scatter.testing import tensor @@ -72,6 +73,29 @@ def test_forward(test, dtype, ind_dtype, place): assert paddle.all(out == expected) +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize( + "test,dtype,ind_dtype", product(tests, dtypes_half, ind_dtypes) +) +def test_forward_half(test, dtype, ind_dtype): + paddle.set_device("gpu") + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + indptr = tensor(test["indptr"], ind_dtype) + expected = tensor(test["expected"], dtype) + + out = gather_csr(src, indptr) + assert 
paddle.all(out == expected) + + out = gather_coo(src, index) + assert paddle.all(out == expected) + + @pytest.mark.parametrize("test,place", product(tests, places)) def test_backward(test, place): paddle.set_device(place) @@ -92,6 +116,32 @@ def test_backward(test, place): assert paddle.all(src.grad == exp_grad) +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize("test", tests) +def test_backward_half(test): + paddle.set_device("gpu") + index = tensor(test["index"], paddle.int64) + indptr = tensor(test["indptr"], paddle.int64) + exp_grad = tensor(test["expected_grad"], paddle.float16) + + src = tensor(test["src"], paddle.float16) + src.stop_gradient = False + out = gather_csr(src, indptr) + out.backward() + assert paddle.all(src.grad == exp_grad) + + src = tensor(test["src"], paddle.float16) + src.stop_gradient = False + out = gather_coo(src, index) + out.backward() + assert paddle.all(src.grad == exp_grad) + + @pytest.mark.parametrize( "test,dtype,ind_dtype,place", product(tests, dtypes, ind_dtypes, places) ) @@ -115,6 +165,35 @@ def test_out(test, dtype, ind_dtype, place): assert paddle.all(out == expected) +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize( + "test,dtype,ind_dtype", product(tests, dtypes_half, ind_dtypes) +) +def test_out_half(test, dtype, ind_dtype): + paddle.set_device("gpu") + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + indptr = tensor(test["indptr"], ind_dtype) + expected = tensor(test["expected"], dtype) + + size = src.shape + size[index.dim() - 1] = index.shape[-1] + out = paddle.full(size, -2).astype(dtype) + + gather_csr(src, indptr, out) + assert paddle.all(out == expected) + + out.fill_(-2) + + gather_coo(src, index, out) + assert paddle.all(out == expected) + + @pytest.mark.parametrize( "test,dtype,ind_dtype,place", product(tests, dtypes, ind_dtypes, places) ) @@ -143,3 +222,39 @@ def test_non_contiguous(test, dtype, ind_dtype, place): out = gather_coo(src, index) assert paddle.all(out == expected) + + +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize( + "test,dtype,ind_dtype", product(tests, dtypes_half, ind_dtypes) +) +def test_non_contiguous_half(test, dtype, ind_dtype): + paddle.set_device("gpu") + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + indptr = tensor(test["indptr"], ind_dtype) + expected = tensor(test["expected"], dtype) + + if src.dim() > 1: + shape = list(range(src.dim())) + shape[0], shape[1] = shape[1], shape[0] + src = src.transpose(shape).contiguous().transpose(shape) + if index.dim() > 1: + shape = list(range(index.dim())) + shape[0], shape[1] = shape[1], shape[0] + index = index.transpose(shape).contiguous().transpose(shape) + if indptr.dim() > 1: + shape = list(range(indptr.dim())) + shape[0], shape[1] = shape[1], shape[0] + indptr = indptr.transpose(shape).contiguous().transpose(shape) + + out = gather_csr(src, indptr) + assert paddle.all(out == expected) + + out = gather_coo(src, index) + assert paddle.all(out == expected) diff --git a/jointContribution/paddle_scatter/tests/test_multi_gpu.py 
b/jointContribution/paddle_scatter/tests/test_multi_gpu.py index e8dd74b6b..c5136d705 100644 --- a/jointContribution/paddle_scatter/tests/test_multi_gpu.py +++ b/jointContribution/paddle_scatter/tests/test_multi_gpu.py @@ -22,8 +22,8 @@ ] -@pytest.mark.skipif(not paddle.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(paddle.cuda.device_count() < 2, reason="No multiple GPUS") +@pytest.mark.skipif(paddle.device.cuda.device_count() == 0, reason="CUDA not available") +@pytest.mark.skipif(paddle.device.cuda.device_count() < 2, reason="No multiple GPUS") @pytest.mark.parametrize("test,reduce,dtype", product(tests, reductions, dtypes)) def test_forward(test, reduce, dtype): paddle.set_device("gpu:1") diff --git a/jointContribution/paddle_scatter/tests/test_scatter.py b/jointContribution/paddle_scatter/tests/test_scatter.py index 51c519222..f15fe3894 100644 --- a/jointContribution/paddle_scatter/tests/test_scatter.py +++ b/jointContribution/paddle_scatter/tests/test_scatter.py @@ -5,6 +5,7 @@ import paddle_scatter import pytest from paddle_scatter.testing import dtypes +from paddle_scatter.testing import dtypes_half from paddle_scatter.testing import ind_dtypes from paddle_scatter.testing import places from paddle_scatter.testing import reductions @@ -221,6 +222,33 @@ def test_forward(test, reduce, dtype, ind_dtype, place): assert paddle.all(out == expected) +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize( + "test,reduce,dtype,ind_dtype", product(tests, reductions, dtypes_half, ind_dtypes) +) +def test_forward_half(test, reduce, dtype, ind_dtype): + paddle.set_device("gpu") + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + dim = test["dim"] + expected = tensor(test[reduce], dtype) + + fn = getattr(paddle_scatter, "scatter_" + reduce) + out = fn(src, index, dim) + if reduce == "min" or reduce == "max": + out, arg_out = out + arg_expected = tensor(test["arg_" + reduce], ind_dtype) + print(arg_out) + print(arg_expected) + assert paddle.all(arg_out == arg_expected) + assert paddle.all(out == expected) + + @pytest.mark.parametrize("test,reduce,place", product(tests, reductions, places)) def test_backward(test, reduce, place): paddle.set_device(place) @@ -237,6 +265,28 @@ def test_backward(test, reduce, place): ) +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize("test,reduce", product(tests, reductions)) +def test_backward_half(test, reduce): + paddle.set_device("gpu") + index = tensor(test["index"], paddle.int64) + dim = test["dim"] + exp_grad = tensor(test[f"{reduce}_grad"], paddle.float16) + + src = tensor(test["src"], paddle.float16) + src.stop_gradient = False + out = paddle_scatter.scatter(src, index, dim, None, None, reduce) + out.backward() + np.testing.assert_allclose( + src.grad.numpy(), exp_grad.numpy(), rtol=1e-05, atol=1e-06 + ) + + @pytest.mark.parametrize( "test,reduce,dtype,ind_dtype,place", product(tests, reductions, dtypes, ind_dtypes, places), @@ -268,11 +318,81 @@ def test_out(test, reduce, dtype, ind_dtype, place): np.testing.assert_allclose(out.numpy(), expected.numpy(), rtol=1e-05, atol=1e-06) +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and 
paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize( + "test,reduce,dtype,ind_dtype", product(tests, reductions, dtypes_half, ind_dtypes) +) +def test_out_half(test, reduce, dtype, ind_dtype): + paddle.set_device("gpu") + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + dim = test["dim"] + expected = tensor(test[reduce], dtype) + + out = paddle.full_like(expected, -2) + + getattr(paddle_scatter, "scatter_" + reduce)(src, index, dim, out) + + if reduce == "sum" or reduce == "add": + expected = expected - 2 + elif reduce == "mul": + expected = out # We can not really test this here. + elif reduce == "mean": + expected = out # We can not really test this here. + elif reduce == "min": + print(expected) + expected = expected.fill_(-2) + elif reduce == "max": + expected[expected == 0] = -2 + else: + raise ValueError + + np.testing.assert_allclose(out.numpy(), expected.numpy(), rtol=1e-05, atol=1e-06) + + +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize( + "test,reduce,dtype,ind_dtype", product(tests, reductions, dtypes_half, ind_dtypes) +) +def test_non_contiguous(test, reduce, dtype, ind_dtype): + paddle.set_device("gpu") + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + dim = test["dim"] + expected = tensor(test[reduce], dtype) + + if src.dim() > 1: + shape = list(range(src.dim())) + shape[0], shape[1] = shape[1], shape[0] + src = src.transpose(shape).contiguous().transpose(shape) + if index.dim() > 1: + shape = list(range(index.dim())) + shape[0], shape[1] = shape[1], shape[0] + index = index.transpose(shape).contiguous().transpose(shape) + + out = getattr(paddle_scatter, "scatter_" + reduce)(src, index, dim) + if reduce == "min" or reduce == "max": + out, arg_out = out + arg_expected = tensor(test["arg_" + reduce], ind_dtype) + assert paddle.all(arg_out == arg_expected) + assert paddle.all(out == expected) + + @pytest.mark.parametrize( "test,reduce,dtype,ind_dtype,place", product(tests, reductions, dtypes, ind_dtypes, places), ) -def test_non_contiguous(test, reduce, dtype, ind_dtype, place): +def test_non_contiguous_half(test, reduce, dtype, ind_dtype, place): paddle.set_device(place) src = tensor(test["src"], dtype) index = tensor(test["index"], ind_dtype) diff --git a/jointContribution/paddle_scatter/tests/test_segment.py b/jointContribution/paddle_scatter/tests/test_segment.py index 4acb788aa..1c3866c45 100644 --- a/jointContribution/paddle_scatter/tests/test_segment.py +++ b/jointContribution/paddle_scatter/tests/test_segment.py @@ -5,6 +5,7 @@ import paddle_scatter import pytest from paddle_scatter.testing import dtypes +from paddle_scatter.testing import dtypes_half from paddle_scatter.testing import ind_dtypes from paddle_scatter.testing import places from paddle_scatter.testing import reductions @@ -200,6 +201,39 @@ def test_forward(test, reduce, dtype, ind_dtype, place): assert paddle.all(out == expected) +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize( + "test,reduce,dtype,ind_dtype", product(tests, reductions, dtypes_half, ind_dtypes) +) +def test_forward_half(test, reduce, dtype, ind_dtype): + paddle.set_device("gpu") + src = tensor(test["src"], 
dtype) + index = tensor(test["index"], ind_dtype) + indptr = tensor(test["indptr"], ind_dtype) + expected = tensor(test[reduce], dtype) + + fn = getattr(paddle_scatter, "segment_" + reduce + "_csr") + out = fn(src, indptr) + if reduce == "min" or reduce == "max": + out, arg_out = out + arg_expected = tensor(test["arg_" + reduce], ind_dtype) + assert paddle.all(arg_out == arg_expected) + assert paddle.all(out == expected) + + fn = getattr(paddle_scatter, "segment_" + reduce + "_coo") + out = fn(src, index) + if reduce == "min" or reduce == "max": + out, arg_out = out + arg_expected = tensor(test["arg_" + reduce], ind_dtype) + assert paddle.all(arg_out == arg_expected) + assert paddle.all(out == expected) + + @pytest.mark.parametrize("test,reduce,place", product(tests, reductions, places)) def test_backward(test, reduce, place): paddle.set_device(place) @@ -224,6 +258,36 @@ def test_backward(test, reduce, place): ) +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize("test,reduce", product(tests, reductions)) +def test_backward_half(test, reduce): + paddle.set_device("gpu") + index = tensor(test["index"], paddle.int64) + indptr = tensor(test["indptr"], paddle.int64) + exp_grad = tensor(test[f"{reduce}_grad"], paddle.float16) + + src = tensor(test["src"], paddle.float16) + src.stop_gradient = False + out = paddle_scatter.segment_csr(src, indptr, None, reduce) + out.backward() + np.testing.assert_allclose( + src.grad.numpy(), exp_grad.numpy(), rtol=1e-05, atol=1e-06 + ) + + src = tensor(test["src"], paddle.float16) + src.stop_gradient = False + out = paddle_scatter.segment_coo(src, index, None, None, reduce) + out.backward() + np.testing.assert_allclose( + src.grad.numpy(), exp_grad.numpy(), rtol=1e-05, atol=1e-06 + ) + + @pytest.mark.parametrize( "test,reduce,dtype,ind_dtype,place", product(tests, reductions, dtypes, ind_dtypes, places), @@ -258,6 +322,45 @@ def test_out(test, reduce, dtype, ind_dtype, place): assert paddle.all(out == expected) +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize( + "test,reduce,dtype,ind_dtype", product(tests, reductions, dtypes_half, ind_dtypes) +) +def test_out_half(test, reduce, dtype, ind_dtype): + paddle.set_device("gpu") + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + indptr = tensor(test["indptr"], ind_dtype) + expected = tensor(test[reduce], dtype) + + out = paddle.full_like(expected, -2) + + getattr(paddle_scatter, "segment_" + reduce + "_csr")(src, indptr, out) + np.testing.assert_allclose(out.numpy(), expected.numpy(), rtol=1e-05, atol=1e-06) + + out.fill_(-2) + + getattr(paddle_scatter, "segment_" + reduce + "_coo")(src, index, out) + + if reduce == "sum" or reduce == "add": + expected = expected - 2 + elif reduce == "mean": + expected = out # We can not really test this here. 
+ elif reduce == "min": + expected = expected.fill_(-2) + elif reduce == "max": + expected[expected == 0] = -2 + else: + raise ValueError + + assert paddle.all(out == expected) + + @pytest.mark.parametrize( "test,reduce,dtype,ind_dtype,place", product(tests, reductions, dtypes, ind_dtypes, places), @@ -295,3 +398,47 @@ def test_non_contiguous(test, reduce, dtype, ind_dtype, place): arg_expected = tensor(test["arg_" + reduce], ind_dtype) assert paddle.all(arg_out == arg_expected) assert paddle.all(out == expected) + + +@pytest.mark.skipif( + not paddle.core.is_compiled_with_cuda() + and paddle.core.is_bfloat16_supported() + and paddle.core.is_float16_supported(), + reason="half dtype not available", +) +@pytest.mark.parametrize( + "test,reduce,dtype,ind_dtype", product(tests, reductions, dtypes_half, ind_dtypes) +) +def test_non_contiguous_half(test, reduce, dtype, ind_dtype): + paddle.set_device("gpu") + src = tensor(test["src"], dtype) + index = tensor(test["index"], ind_dtype) + indptr = tensor(test["indptr"], ind_dtype) + expected = tensor(test[reduce], dtype) + + if src.dim() > 1: + shape = list(range(src.dim())) + shape[0], shape[1] = shape[1], shape[0] + src = src.transpose(shape).contiguous().transpose(shape) + if index.dim() > 1: + shape = list(range(index.dim())) + shape[0], shape[1] = shape[1], shape[0] + index = index.transpose(shape).contiguous().transpose(shape) + if indptr.dim() > 1: + shape = list(range(index.dim())) + shape[0], shape[1] = shape[1], shape[0] + indptr = indptr.transpose(shape).contiguous().transpose(shape) + + out = getattr(paddle_scatter, "segment_" + reduce + "_csr")(src, indptr) + if reduce == "min" or reduce == "max": + out, arg_out = out + arg_expected = tensor(test["arg_" + reduce], ind_dtype) + assert paddle.all(arg_out == arg_expected) + assert paddle.all(out == expected) + + out = getattr(paddle_scatter, "segment_" + reduce + "_coo")(src, index) + if reduce == "min" or reduce == "max": + out, arg_out = out + arg_expected = tensor(test["arg_" + reduce], ind_dtype) + assert paddle.all(arg_out == arg_expected) + assert paddle.all(out == expected) diff --git a/jointContribution/paddle_scatter/utils.py b/jointContribution/paddle_scatter/utils.py index b2db4874f..7356a011b 100644 --- a/jointContribution/paddle_scatter/utils.py +++ b/jointContribution/paddle_scatter/utils.py @@ -34,40 +34,3 @@ def broadcast(src: paddle.Tensor, other: paddle.Tensor, dim: int) -> paddle.Tens return src.reshape(other.shape) src = src.expand(other.shape) return src - - -def transform_3d(tensor: paddle.Tensor, dim: int) -> paddle.Tensor: - r"""Transform a tensor to [pre, dim, post] 3d dimensions. - - Args: - tensor (paddle.Tensor): The input tensor. - dim (int): The target dimension. - - Returns: - paddle.Tensor, the transformed tensor. - """ - if tensor.dim() == 1: - return tensor.unsqueeze(0).unsqueeze(-1) - elif tensor.dim() == 2: - if dim == 0: - return tensor.unsqueeze(0) - elif dim == 1: - return tensor.unsqueeze(-1) - else: - return tensor.flatten(0, dim - 1).flatten(2, -1) - - -def transform_2d(tensor: paddle.Tensor, dim: int) -> paddle.Tensor: - r"""Transform a tensor to [pre, dim] 2d dimensions. - - Args: - tensor (paddle.Tensor): The input tensor. - dim (int): The target dimension. - - Returns: - paddle.Tensor, the transformed tensor. - """ - if tensor.dim() == 1: - return tensor.unsqueeze(0) - else: - return tensor.flatten(0, dim - 1)
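For reference, a minimal usage sketch of the CSR-style API exercised by the tests above, assuming `segment_csr(src, indptr, out=None, reduce="sum")` and `gather_csr(src, indptr, out=None)` keep the signatures and broadcasting semantics of the pytorch_scatter originals (the values below are illustrative only):

import paddle
from paddle_scatter import segment_csr, gather_csr

src = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0],
                        [5.0, 6.0, 7.0, 8.0]])
# indptr marks segment boundaries along the last dimension: [0, 2) and [2, 4)
indptr = paddle.to_tensor([[0, 2, 4]])

out = segment_csr(src, indptr, reduce="mean")
# out  -> [[1.5, 3.5],
#          [5.5, 7.5]]

back = gather_csr(out, indptr)
# back -> [[1.5, 1.5, 3.5, 3.5],
#          [5.5, 5.5, 7.5, 7.5]]

Each output row holds one reduced value per segment, and `gather_csr` broadcasts those values back out to the segment lengths encoded in `indptr`.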
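The half-precision tests gate on a CUDA build plus device support for float16 and bfloat16. A reusable pytest marker for that check could look like the sketch below; this is a non-authoritative sketch that reuses the `paddle.core` predicates the tests already call, and the helper names `half_dtypes_available` and `require_half` are illustrative, not part of the package:

import paddle
import pytest


def half_dtypes_available() -> bool:
    # Half kernels need a CUDA build and device support for both half dtypes.
    return (
        paddle.core.is_compiled_with_cuda()
        and paddle.core.is_float16_supported()
        and paddle.core.is_bfloat16_supported()
    )


# Skip when any of the three conditions fails, i.e. negate the whole conjunction.
require_half = pytest.mark.skipif(
    not half_dtypes_available(), reason="half dtype not available"
)

Applied as `@require_half` above a test function, this keeps the availability check in one place instead of repeating the three-clause condition in every half-dtype test.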