-
Notifications
You must be signed in to change notification settings - Fork 178
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Paddle Toolkit Development Competition No.4】 Paddle 适配 torch-scatter #1028
Open
NKNaN
wants to merge
2
commits into
PaddlePaddle:develop
Choose a base branch
from
NKNaN:pscatter
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+6,806
−0
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from .composite import scatter_log_softmax | ||
from .composite import scatter_logsumexp | ||
from .composite import scatter_softmax | ||
from .composite import scatter_std | ||
from .scatter import scatter | ||
from .scatter import scatter_add | ||
from .scatter import scatter_max | ||
from .scatter import scatter_mean | ||
from .scatter import scatter_min | ||
from .scatter import scatter_mul | ||
from .scatter import scatter_sum | ||
from .segment_coo import gather_coo | ||
from .segment_coo import segment_add_coo | ||
from .segment_coo import segment_coo | ||
from .segment_coo import segment_max_coo | ||
from .segment_coo import segment_mean_coo | ||
from .segment_coo import segment_min_coo | ||
from .segment_coo import segment_sum_coo | ||
from .segment_csr import gather_csr | ||
from .segment_csr import segment_add_csr | ||
from .segment_csr import segment_csr | ||
from .segment_csr import segment_max_csr | ||
from .segment_csr import segment_mean_csr | ||
from .segment_csr import segment_min_csr | ||
from .segment_csr import segment_sum_csr | ||
|
||
__all__ = [ | ||
"scatter_sum", | ||
"scatter_add", | ||
"scatter_mul", | ||
"scatter_mean", | ||
"scatter_min", | ||
"scatter_max", | ||
"scatter", | ||
"segment_sum_csr", | ||
"segment_add_csr", | ||
"segment_mean_csr", | ||
"segment_min_csr", | ||
"segment_max_csr", | ||
"segment_csr", | ||
"gather_csr", | ||
"segment_sum_coo", | ||
"segment_add_coo", | ||
"segment_mean_coo", | ||
"segment_min_coo", | ||
"segment_max_coo", | ||
"segment_coo", | ||
"gather_coo", | ||
"scatter_std", | ||
"scatter_logsumexp", | ||
"scatter_softmax", | ||
"scatter_log_softmax", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from .logsumexp import scatter_logsumexp | ||
from .softmax import scatter_log_softmax | ||
from .softmax import scatter_softmax | ||
from .std import scatter_std | ||
|
||
__all__ = [ | ||
"scatter_std", | ||
"scatter_logsumexp", | ||
"scatter_softmax", | ||
"scatter_log_softmax", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
from typing import Optional | ||
|
||
import paddle | ||
from paddle_scatter import scatter_max | ||
from paddle_scatter import scatter_sum | ||
from paddle_scatter.utils import broadcast | ||
|
||
|
||
def scatter_logsumexp( | ||
src: paddle.Tensor, | ||
index: paddle.Tensor, | ||
dim: int = -1, | ||
out: Optional[paddle.Tensor] = None, | ||
dim_size: Optional[int] = None, | ||
eps: float = 1e-12, | ||
) -> paddle.Tensor: | ||
r"""Reduces all values from the `src` tensor into `out` at the | ||
indices specified in the `index` tensor along a given axis`dim`, | ||
the reduction method is logsumexp. (If dtype of `src` is int, output is still int.) | ||
|
||
Args: | ||
src (paddle.Tensor): The source tensor. | ||
index (paddle.Tensor): The indices of elements to scatter. The dimension | ||
of index should either be 1-D or :math:`i+1`-D. See Notes for more | ||
details. | ||
dim (int, optional): The axis along which to index. Default is -1. | ||
out (paddle.Tensor|None, optional): The destination tensor. Default is None. | ||
dim_size (int|None, optional): If `out` is not given, automatically create output | ||
with size `dim_size` at dimension `dim`. If `dim_size` is not given, | ||
a minimal sized output tensor according to `index.max() + 1` is returned. | ||
Default is None. | ||
eps (float, optional): Eplison factor added to the sum of exponent values during | ||
computation in case they are zero. Default is 1e-12. | ||
|
||
Returns: | ||
paddle.Tensor, the reduced tensor by logsumexp reduction method. | ||
""" | ||
if not paddle.is_floating_point(src): | ||
raise ValueError( | ||
"`scatter_logsumexp` can only be computed over " | ||
"tensors with floating point data types." | ||
) | ||
|
||
index = broadcast(index, src, dim) | ||
eps = paddle.full([], eps, dtype=src.dtype) | ||
|
||
if out is not None: | ||
dim_size = out.shape[dim] | ||
else: | ||
if dim_size is None: | ||
dim_size = int(index.max()) + 1 | ||
|
||
size = src.shape | ||
size[dim] = dim_size | ||
max_value_per_index = paddle.full( | ||
size, | ||
fill_value=float("-inf"), | ||
dtype=src.dtype, | ||
) | ||
scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size)[0] | ||
max_per_src_element = max_value_per_index.take_along_axis(indices=index, axis=dim) | ||
recentered_score = src - max_per_src_element | ||
recentered_score.masked_fill_(paddle.isnan(recentered_score), float("-inf")) | ||
|
||
orig_out: Optional[paddle.Tensor] = None | ||
if out is not None: | ||
orig_out = out.clone() | ||
res = out.subtract(max_value_per_index).exp() | ||
|
||
sum_per_index = scatter_sum(recentered_score.exp(), index, dim, res, dim_size) | ||
else: | ||
sum_per_index = scatter_sum(recentered_score.exp(), index, dim, None, dim_size) | ||
|
||
res = sum_per_index.add(eps).log().add(max_value_per_index) | ||
|
||
if orig_out is None: | ||
return res.nan_to_num_(neginf=0.0) | ||
|
||
mask = ~res.isfinite() | ||
res[mask] = orig_out[mask] | ||
paddle.assign(res, out) | ||
return out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from typing import Optional | ||
|
||
import paddle | ||
from paddle_scatter import scatter_max | ||
from paddle_scatter import scatter_sum | ||
from paddle_scatter.utils import broadcast | ||
|
||
|
||
def scatter_softmax( | ||
src: paddle.Tensor, | ||
index: paddle.Tensor, | ||
dim: int = -1, | ||
dim_size: Optional[int] = None, | ||
) -> paddle.Tensor: | ||
r"""Reduces all values from the `src` tensor into `out` at the | ||
indices specified in the `index` tensor along a given axis`dim`, | ||
the reduction method is softmax. (If dtype of `src` is int, output is still int.) | ||
|
||
Args: | ||
src (paddle.Tensor): The source tensor. | ||
index (paddle.Tensor): The indices of elements to scatter. The dimension | ||
of index should either be 1-D or :math:`i+1`-D. See Notes for more | ||
details. | ||
dim (int, optional): The axis along which to index. Default is -1. | ||
dim_size (int|None, optional): If `out` is not given, automatically create output | ||
with size `dim_size` at dimension `dim`. If `dim_size` is not given, | ||
a minimal sized output tensor according to `index.max() + 1` is returned. | ||
Default is None. | ||
|
||
Returns: | ||
paddle.Tensor, the reduced tensor by softmax reduction method. | ||
""" | ||
if not paddle.is_floating_point(src): | ||
raise ValueError( | ||
"`scatter_softmax` can only be computed over tensors " | ||
"with floating point data types." | ||
) | ||
|
||
index = broadcast(index, src, dim) | ||
|
||
max_value_per_index = scatter_max(src, index, dim=dim, dim_size=dim_size)[0] | ||
max_per_src_element = max_value_per_index.take_along_axis(indices=index, axis=dim) | ||
|
||
recentered_scores = src - max_per_src_element | ||
recentered_scores_exp = recentered_scores.exp() | ||
|
||
sum_per_index = scatter_sum(recentered_scores_exp, index, dim, dim_size=dim_size) | ||
normalizing_constants = sum_per_index.take_along_axis(indices=index, axis=dim) | ||
|
||
return recentered_scores_exp.divide(normalizing_constants) | ||
|
||
|
||
def scatter_log_softmax( | ||
src: paddle.Tensor, | ||
index: paddle.Tensor, | ||
dim: int = -1, | ||
eps: float = 1e-12, | ||
dim_size: Optional[int] = None, | ||
) -> paddle.Tensor: | ||
r"""Reduces all values from the `src` tensor into `out` at the | ||
indices specified in the `index` tensor along a given axis`dim`, | ||
the reduction method is log_softmax. (If dtype of `src` is int, output is still int.) | ||
|
||
Args: | ||
src (paddle.Tensor): The source tensor. | ||
index (paddle.Tensor): The indices of elements to scatter. The dimension | ||
of index should either be 1-D or :math:`i+1`-D. See Notes for more | ||
details. | ||
dim (int, optional): The axis along which to index. Default is -1. | ||
eps (float, optional): Eplison factor added to the normalizing constants during | ||
computation in case they are zero. Default is 1e-12. | ||
dim_size (int|None, optional): If `out` is not given, automatically create output | ||
with size `dim_size` at dimension `dim`. If `dim_size` is not given, | ||
a minimal sized output tensor according to `index.max() + 1` is returned. | ||
Default is None. | ||
|
||
Returns: | ||
paddle.Tensor, the reduced tensor by log_softmax reduction method. | ||
""" | ||
if not paddle.is_floating_point(src): | ||
raise ValueError( | ||
"`scatter_log_softmax` can only be computed over " | ||
"tensors with floating point data types." | ||
) | ||
|
||
index = broadcast(index, src, dim) | ||
eps = paddle.full([], eps, dtype=src.dtype) | ||
|
||
max_value_per_index = scatter_max(src, index, dim=dim, dim_size=dim_size)[0] | ||
max_per_src_element = max_value_per_index.take_along_axis(indices=index, axis=dim) | ||
|
||
recentered_scores = src - max_per_src_element | ||
|
||
sum_per_index = scatter_sum(recentered_scores.exp(), index, dim, dim_size=dim_size) | ||
normalizing_constants = ( | ||
sum_per_index.add(eps).log().take_along_axis(indices=index, axis=dim) | ||
) | ||
|
||
return recentered_scores.subtract(normalizing_constants) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from typing import Optional | ||
|
||
import paddle | ||
from paddle_scatter import scatter_sum | ||
from paddle_scatter.utils import broadcast | ||
|
||
|
||
def scatter_std( | ||
src: paddle.Tensor, | ||
index: paddle.Tensor, | ||
dim: int = -1, | ||
out: Optional[paddle.Tensor] = None, | ||
dim_size: Optional[int] = None, | ||
unbiased: bool = True, | ||
) -> paddle.Tensor: | ||
r"""Reduces all values from the `src` tensor into `out` at the | ||
indices specified in the `index` tensor along a given axis`dim`, | ||
the reduction method is std. (If dtype of `src` is int, output is still int.) | ||
|
||
Args: | ||
src (paddle.Tensor): The source tensor. | ||
index (paddle.Tensor): The indices of elements to scatter. The dimension | ||
of index should either be 1-D or :math:`i+1`-D. See Notes for more | ||
details. | ||
dim (int, optional): The axis along which to index. Default is -1. | ||
out (paddle.Tensor|None, optional): The destination tensor. Default is None. | ||
dim_size (int|None, optional): If `out` is not given, automatically create output | ||
with size `dim_size` at dimension `dim`. If `dim_size` is not given, | ||
a minimal sized output tensor according to `index.max() + 1` is returned. | ||
Default is None. | ||
unbiased (bool, optional): Indicate whether to calculate biased std (divide by n) | ||
or unbiased std (divide by n-1). Default is True. | ||
|
||
Returns: | ||
paddle.Tensor, the reduced tensor by std reduction method. | ||
""" | ||
if out is not None: | ||
dim_size = out.shape[dim] | ||
|
||
if dim < 0: | ||
dim = src.dim() + dim | ||
|
||
count_dim = dim | ||
if index.dim() <= dim: | ||
count_dim = index.dim() - 1 | ||
|
||
ones = paddle.ones(index.shape, dtype=src.dtype) | ||
count = scatter_sum(ones, index, count_dim, dim_size=dim_size) | ||
|
||
index = broadcast(index, src, dim) | ||
tmp = scatter_sum(src, index, dim, dim_size=dim_size) | ||
count = broadcast(count, tmp, dim).clip(1) | ||
mean = tmp.divide(count) | ||
|
||
var = src - mean.take_along_axis(indices=index, axis=dim) | ||
var = var * var | ||
res = scatter_sum(var, index, dim, out, dim_size) | ||
|
||
if unbiased: | ||
count = count.subtract(paddle.full([], 1, dtype=src.dtype)).clip(1) | ||
res = res.divide(count + 1e-6).sqrt() | ||
|
||
if out is not None: | ||
paddle.assign(res, out) | ||
return out | ||
else: | ||
return res | ||
Comment on lines
+63
to
+67
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里这样写是对的,如果out是None,则返回res |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
有个问题,这里如果out为None的话,是不是不正确? paddle.assign(res, out)之后,out应该仍然是None
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
out为None的话 orig_out 也为None,会走上面那个分支return