Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
NKNaN committed Dec 2, 2024
1 parent 8c5351c commit 3993fb4
Show file tree
Hide file tree
Showing 23 changed files with 2,501 additions and 324 deletions.
2 changes: 1 addition & 1 deletion jointContribution/paddle_scatter/composite/logsumexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def scatter_logsumexp(
)

index = broadcast(index, src, dim)
eps = paddle.to_tensor(eps, dtype=src.dtype)
eps = paddle.full([], eps, dtype=src.dtype)

if out is not None:
dim_size = out.shape[dim]
Expand Down
2 changes: 1 addition & 1 deletion jointContribution/paddle_scatter/composite/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def scatter_log_softmax(
)

index = broadcast(index, src, dim)
eps = paddle.to_tensor(eps, dtype=src.dtype)
eps = paddle.full([], eps, dtype=src.dtype)

max_value_per_index = scatter_max(src, index, dim=dim, dim_size=dim_size)[0]
max_per_src_element = max_value_per_index.take_along_axis(indices=index, axis=dim)
Expand Down
2 changes: 1 addition & 1 deletion jointContribution/paddle_scatter/composite/std.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def scatter_std(
res = scatter_sum(var, index, dim, out, dim_size)

if unbiased:
count = count.subtract(paddle.to_tensor(1, dtype=src.dtype)).clip(1)
count = count.subtract(paddle.full([], 1, dtype=src.dtype)).clip(1)
res = res.divide(count + 1e-6).sqrt()

if out is not None:
Expand Down
31 changes: 31 additions & 0 deletions jointContribution/paddle_scatter/csrc/atomics.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,34 @@ static inline __device__ void atomMin(float *address, float val) {
// Atomic min for double: CUDA provides no native atomicMin(double*), so this
// delegates to AtomicMinDecimalImpl (declared earlier in this header,
// presumably a CAS-based emulation — not visible in this view).
static inline __device__ void atomMin(double *address, double val) {
AtomicMinDecimalImpl<double, sizeof(double)>()(address, val);
}

// OP is consumed by the ATOMIC macro (defined elsewhere in this header),
// which presumably stamps out additional atomAdd overloads; the operation is
// written as Y + X so the stored value is the left operand of the combine.
#define OP(X, Y) Y + X
ATOMIC(Add)
#undef OP
// Sub-word and 64-bit signed integer overloads delegate to
// AtomicAddIntegerImpl (declared earlier in this header; implementation not
// visible here — likely a wider-word CAS loop, TODO confirm).
static inline __device__ void atomAdd(uint8_t *address, uint8_t val) {
AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomAdd(int8_t *address, int8_t val) {
AtomicAddIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomAdd(int16_t *address, int16_t val) {
AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
// 32-bit int maps directly onto the hardware atomicAdd intrinsic.
static inline __device__ void atomAdd(int32_t *address, int32_t val) {
atomicAdd(address, val);
}
static inline __device__ void atomAdd(int64_t *address, int64_t val) {
AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
// float has a native atomicAdd on all supported architectures.
static inline __device__ void atomAdd(float *address, float val) {
atomicAdd(address, val);
}
// Native atomicAdd(double*) requires compute capability >= 6.0 and CUDA 8+;
// on older targets fall back to the CAS-based decimal implementation.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
static inline __device__ void atomAdd(double *address, double val) {
AtomicAddDecimalImpl<double, sizeof(double)>()(address, val);
}
#else
static inline __device__ void atomAdd(double *address, double val) {
atomicAdd(address, val);
}
#endif
7 changes: 2 additions & 5 deletions jointContribution/paddle_scatter/csrc/index_info.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

#include "paddle/extension.h"

#define MAX_TENSORINFO_DIMS 25
#define MAX_TENSORINFO_DIMS 7

template <typename T> struct TensorInfo {
TensorInfo(const T *p, int dim, int sz[MAX_TENSORINFO_DIMS],
int st[MAX_TENSORINFO_DIMS]) {
data = p;
dims = dim;
PD_CHECK(dims < MAX_TENSORINFO_DIMS, "Input dims should be smaller than 25.");
PD_CHECK(dims < MAX_TENSORINFO_DIMS, "Input dims should be smaller than 7.");

for (int i = 0; i < dim; ++i) {
sizes[i] = sz[i];
Expand All @@ -30,11 +30,8 @@ TensorInfo<T> getTensorInfo(const paddle::Tensor &tensor) {
int strides[MAX_TENSORINFO_DIMS];

int dims = tensor.shape().size();
// int stride_i = 1;
for (int i = dims - 1; i >= 0; --i) {
sizes[i] = tensor.shape()[i];
// strides[i] = stride_i;
// stride_i *= sizes[i];
strides[i] = tensor.strides()[i];
}

Expand Down
7 changes: 2 additions & 5 deletions jointContribution/paddle_scatter/csrc/index_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

#include "paddle/extension.h"

#define MAX_TENSORINFO_DIMS 25
#define MAX_TENSORINFO_DIMS 7

template <typename T> struct TensorInfo {
TensorInfo(const T *p, int dim, int sz[MAX_TENSORINFO_DIMS],
int st[MAX_TENSORINFO_DIMS]) {
data = p;
dims = dim;
PD_CHECK(dims < MAX_TENSORINFO_DIMS, "Input dims should be smaller than 25.");
PD_CHECK(dims < MAX_TENSORINFO_DIMS, "Input dims should be smaller than 7.");

for (int i = 0; i < dim; ++i) {
sizes[i] = sz[i];
Expand All @@ -29,11 +29,8 @@ TensorInfo<T> getTensorInfo(const paddle::Tensor &tensor) {
int strides[MAX_TENSORINFO_DIMS];

int dims = tensor.shape().size();
// int stride_i = 1;
for (int i = dims - 1; i >= 0; --i) {
sizes[i] = tensor.shape()[i];
// strides[i] = stride_i;
// stride_i *= sizes[i];
strides[i] = tensor.strides()[i];
}

Expand Down
2 changes: 1 addition & 1 deletion jointContribution/paddle_scatter/csrc/scatter_min_max.cu
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ std::vector<paddle::Tensor> scatter_min_max_cuda_forward(const paddle::Tensor& x
using MPType = typename MPTypeTrait<data_t>::Type;
paddle::Tensor out_mp;
if (x.dtype() == paddle::DataType::FLOAT16 || x.dtype() == paddle::DataType::BFLOAT16) {
out_mp = paddle::empty(return_shape, paddle::DataType::FLOAT32, x.place());
out_mp = paddle::experimental::cast(out, paddle::DataType::FLOAT32);
} else {
out_mp = out;
}
Expand Down
Loading

0 comments on commit 3993fb4

Please sign in to comment.