Skip to content

Commit

Permalink
add paddle-scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
NKNaN committed Nov 24, 2024
1 parent 4857525 commit 77fcc41
Show file tree
Hide file tree
Showing 31 changed files with 4,629 additions and 0 deletions.
53 changes: 53 additions & 0 deletions jointContribution/paddle_scatter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from .composite import scatter_log_softmax
from .composite import scatter_logsumexp
from .composite import scatter_softmax
from .composite import scatter_std
from .scatter import scatter
from .scatter import scatter_add
from .scatter import scatter_max
from .scatter import scatter_mean
from .scatter import scatter_min
from .scatter import scatter_mul
from .scatter import scatter_sum
from .segment_coo import gather_coo
from .segment_coo import segment_add_coo
from .segment_coo import segment_coo
from .segment_coo import segment_max_coo
from .segment_coo import segment_mean_coo
from .segment_coo import segment_min_coo
from .segment_coo import segment_sum_coo
from .segment_csr import gather_csr
from .segment_csr import segment_add_csr
from .segment_csr import segment_csr
from .segment_csr import segment_max_csr
from .segment_csr import segment_mean_csr
from .segment_csr import segment_min_csr
from .segment_csr import segment_sum_csr

__all__ = [
"scatter_sum",
"scatter_add",
"scatter_mul",
"scatter_mean",
"scatter_min",
"scatter_max",
"scatter",
"segment_sum_csr",
"segment_add_csr",
"segment_mean_csr",
"segment_min_csr",
"segment_max_csr",
"segment_csr",
"gather_csr",
"segment_sum_coo",
"segment_add_coo",
"segment_mean_coo",
"segment_min_coo",
"segment_max_coo",
"segment_coo",
"gather_coo",
"scatter_std",
"scatter_logsumexp",
"scatter_softmax",
"scatter_log_softmax",
]
11 changes: 11 additions & 0 deletions jointContribution/paddle_scatter/composite/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from .logsumexp import scatter_logsumexp
from .softmax import scatter_log_softmax
from .softmax import scatter_softmax
from .std import scatter_std

__all__ = [
"scatter_std",
"scatter_logsumexp",
"scatter_softmax",
"scatter_log_softmax",
]
82 changes: 82 additions & 0 deletions jointContribution/paddle_scatter/composite/logsumexp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from typing import Optional

import paddle
from paddle_scatter import scatter_max
from paddle_scatter import scatter_sum
from paddle_scatter.utils import broadcast


def scatter_logsumexp(
src: paddle.Tensor,
index: paddle.Tensor,
dim: int = -1,
out: Optional[paddle.Tensor] = None,
dim_size: Optional[int] = None,
eps: float = 1e-12,
) -> paddle.Tensor:
r"""Reduces all values from the `src` tensor into `out` at the
indices specified in the `index` tensor along a given axis`dim`,
the reduction method is logsumexp. (If dtype of `src` is int, output is still int.)
Args:
src (paddle.Tensor): The source tensor.
index (paddle.Tensor): The indices of elements to scatter. The dimension
of index should either be 1-D or :math:`i+1`-D. See Notes for more
details.
dim (int, optional): The axis along which to index. Default is -1.
out (paddle.Tensor|None, optional): The destination tensor. Default is None.
dim_size (int|None, optional): If `out` is not given, automatically create output
with size `dim_size` at dimension `dim`. If `dim_size` is not given,
a minimal sized output tensor according to `index.max() + 1` is returned.
Default is None.
eps (float, optional): Eplison factor added to the sum of exponent values during
computation in case they are zero. Default is 1e-12.
Returns:
paddle.Tensor, the reduced tensor by logsumexp reduction method.
"""
if not paddle.is_floating_point(src):
raise ValueError(
"`scatter_logsumexp` can only be computed over "
"tensors with floating point data types."
)

index = broadcast(index, src, dim)
eps = paddle.to_tensor(eps, dtype=src.dtype)

if out is not None:
dim_size = out.shape[dim]
else:
if dim_size is None:
dim_size = int(index.max()) + 1

size = src.shape
size[dim] = dim_size
max_value_per_index = paddle.full(
size,
fill_value=float("-inf"),
dtype=src.dtype,
)
scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size)[0]
max_per_src_element = max_value_per_index.take_along_axis(indices=index, axis=dim)
recentered_score = src - max_per_src_element
recentered_score.masked_fill_(paddle.isnan(recentered_score), float("-inf"))

orig_out: Optional[paddle.Tensor] = None
if out is not None:
orig_out = out.clone()
res = out.subtract(max_value_per_index).exp()

sum_per_index = scatter_sum(recentered_score.exp(), index, dim, res, dim_size)
else:
sum_per_index = scatter_sum(recentered_score.exp(), index, dim, None, dim_size)

res = sum_per_index.add(eps).log().add(max_value_per_index)

if orig_out is None:
return res.nan_to_num_(neginf=0.0)

mask = ~res.isfinite()
res[mask] = orig_out[mask]
paddle.assign(res, out)
return out
99 changes: 99 additions & 0 deletions jointContribution/paddle_scatter/composite/softmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Optional

import paddle
from paddle_scatter import scatter_max
from paddle_scatter import scatter_sum
from paddle_scatter.utils import broadcast


def scatter_softmax(
src: paddle.Tensor,
index: paddle.Tensor,
dim: int = -1,
dim_size: Optional[int] = None,
) -> paddle.Tensor:
r"""Reduces all values from the `src` tensor into `out` at the
indices specified in the `index` tensor along a given axis`dim`,
the reduction method is softmax. (If dtype of `src` is int, output is still int.)
Args:
src (paddle.Tensor): The source tensor.
index (paddle.Tensor): The indices of elements to scatter. The dimension
of index should either be 1-D or :math:`i+1`-D. See Notes for more
details.
dim (int, optional): The axis along which to index. Default is -1.
dim_size (int|None, optional): If `out` is not given, automatically create output
with size `dim_size` at dimension `dim`. If `dim_size` is not given,
a minimal sized output tensor according to `index.max() + 1` is returned.
Default is None.
Returns:
paddle.Tensor, the reduced tensor by softmax reduction method.
"""
if not paddle.is_floating_point(src):
raise ValueError(
"`scatter_softmax` can only be computed over tensors "
"with floating point data types."
)

index = broadcast(index, src, dim)

max_value_per_index = scatter_max(src, index, dim=dim, dim_size=dim_size)[0]
max_per_src_element = max_value_per_index.take_along_axis(indices=index, axis=dim)

recentered_scores = src - max_per_src_element
recentered_scores_exp = recentered_scores.exp()

sum_per_index = scatter_sum(recentered_scores_exp, index, dim, dim_size=dim_size)
normalizing_constants = sum_per_index.take_along_axis(indices=index, axis=dim)

return recentered_scores_exp.divide(normalizing_constants)


def scatter_log_softmax(
src: paddle.Tensor,
index: paddle.Tensor,
dim: int = -1,
eps: float = 1e-12,
dim_size: Optional[int] = None,
) -> paddle.Tensor:
r"""Reduces all values from the `src` tensor into `out` at the
indices specified in the `index` tensor along a given axis`dim`,
the reduction method is log_softmax. (If dtype of `src` is int, output is still int.)
Args:
src (paddle.Tensor): The source tensor.
index (paddle.Tensor): The indices of elements to scatter. The dimension
of index should either be 1-D or :math:`i+1`-D. See Notes for more
details.
dim (int, optional): The axis along which to index. Default is -1.
eps (float, optional): Eplison factor added to the normalizing constants during
computation in case they are zero. Default is 1e-12.
dim_size (int|None, optional): If `out` is not given, automatically create output
with size `dim_size` at dimension `dim`. If `dim_size` is not given,
a minimal sized output tensor according to `index.max() + 1` is returned.
Default is None.
Returns:
paddle.Tensor, the reduced tensor by log_softmax reduction method.
"""
if not paddle.is_floating_point(src):
raise ValueError(
"`scatter_log_softmax` can only be computed over "
"tensors with floating point data types."
)

index = broadcast(index, src, dim)
eps = paddle.to_tensor(eps, dtype=src.dtype)

max_value_per_index = scatter_max(src, index, dim=dim, dim_size=dim_size)[0]
max_per_src_element = max_value_per_index.take_along_axis(indices=index, axis=dim)

recentered_scores = src - max_per_src_element

sum_per_index = scatter_sum(recentered_scores.exp(), index, dim, dim_size=dim_size)
normalizing_constants = (
sum_per_index.add(eps).log().take_along_axis(indices=index, axis=dim)
)

return recentered_scores.subtract(normalizing_constants)
67 changes: 67 additions & 0 deletions jointContribution/paddle_scatter/composite/std.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from typing import Optional

import paddle
from paddle_scatter import scatter_sum
from paddle_scatter.utils import broadcast


def scatter_std(
src: paddle.Tensor,
index: paddle.Tensor,
dim: int = -1,
out: Optional[paddle.Tensor] = None,
dim_size: Optional[int] = None,
unbiased: bool = True,
) -> paddle.Tensor:
r"""Reduces all values from the `src` tensor into `out` at the
indices specified in the `index` tensor along a given axis`dim`,
the reduction method is std. (If dtype of `src` is int, output is still int.)
Args:
src (paddle.Tensor): The source tensor.
index (paddle.Tensor): The indices of elements to scatter. The dimension
of index should either be 1-D or :math:`i+1`-D. See Notes for more
details.
dim (int, optional): The axis along which to index. Default is -1.
out (paddle.Tensor|None, optional): The destination tensor. Default is None.
dim_size (int|None, optional): If `out` is not given, automatically create output
with size `dim_size` at dimension `dim`. If `dim_size` is not given,
a minimal sized output tensor according to `index.max() + 1` is returned.
Default is None.
unbiased (bool, optional): Indicate whether to calculate biased std (divide by n)
or unbiased std (divide by n-1). Default is True.
Returns:
paddle.Tensor, the reduced tensor by std reduction method.
"""
if out is not None:
dim_size = out.shape[dim]

if dim < 0:
dim = src.dim() + dim

count_dim = dim
if index.dim() <= dim:
count_dim = index.dim() - 1

ones = paddle.ones(index.shape, dtype=src.dtype)
count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

index = broadcast(index, src, dim)
tmp = scatter_sum(src, index, dim, dim_size=dim_size)
count = broadcast(count, tmp, dim).clip(1)
mean = tmp.divide(count)

var = src - mean.take_along_axis(indices=index, axis=dim)
var = var * var
res = scatter_sum(var, index, dim, out, dim_size)

if unbiased:
count = count.subtract(paddle.to_tensor(1, dtype=src.dtype)).clip(1)
res = res.divide(count + 1e-6).sqrt()

if out is not None:
paddle.assign(res, out)
return out
else:
return res
Loading

0 comments on commit 77fcc41

Please sign in to comment.