Support Conv3D ONNX frontend and AntaresIR; Format code style #451

Open · wants to merge 3 commits into base: main
7 changes: 4 additions & 3 deletions src/nnfusion/core/kernels/cuda_gpu/cuda_cudnn.cpp
@@ -201,11 +201,12 @@ LanguageUnit_p cuda::get_cudnn_convolution_descriptor(const Shape& padding,
<< "window_dilation_strides_int, CUDNN_CROSS_CORRELATION, " << data_type << "));\n";
}

if(type == nnfusion::element::f16){
if (type == nnfusion::element::f16)
{
// half precision, use tensor core
lu << "CUDNN_SAFE_CALL(cudnnSetConvolutionMathType(" << desc << ", "
<< "CUDNN_TENSOR_OP_MATH"
<< "));\n";
<< "CUDNN_TENSOR_OP_MATH"
<< "));\n";
}

return _lu;
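For an f16 Conv3D, the strings emitted above expand into host-side cuDNN setup roughly like the following; a minimal sketch, assuming a descriptor named `conv_desc` and padding/stride/dilation arrays following the `*_int` naming used in the emitter:

```cpp
#include <cudnn.h>

// Illustrative expansion of the emitted descriptor setup for an f16 3-D convolution
// (fragment of the generated host function; the *_int arrays come from the caller).
cudnnConvolutionDescriptor_t conv_desc;
CUDNN_SAFE_CALL(cudnnCreateConvolutionDescriptor(&conv_desc));

// Three spatial dims (D, H, W); the *_int arrays hold padding, strides, dilations.
CUDNN_SAFE_CALL(cudnnSetConvolutionNdDescriptor(conv_desc,
                                                /*arrayLength=*/3,
                                                padding_int,
                                                window_movement_strides_int,
                                                window_dilation_strides_int,
                                                CUDNN_CROSS_CORRELATION,
                                                CUDNN_DATA_HALF));

// Half precision: opt in to tensor-core kernels, which is what this change adds.
CUDNN_SAFE_CALL(cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH));
```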
9 changes: 6 additions & 3 deletions src/nnfusion/core/kernels/cuda_gpu/kernels/batch_matmul.cpp
@@ -103,8 +103,9 @@ namespace nnfusion
@hCublas@, @transA@, @transB@, @m@, @n@, @k@,
&alpha, input1, @lda@, @stride_a@, input0, @ldb@, @stride_b@,
&beta, output0, @ldc@, @stride_c@, @batch@));
)" :
R"(
)"
:
R"(
static const float alpha = @alpha@F, beta = @beta@F;
// if (!@hCublas@)
// CUBLAS_SAFE_CALL(@api_create@(&@hCublas@));
@@ -116,7 +117,9 @@ namespace nnfusion
{
{"hCublas", "cublas_handle"},
{"api_create", "cublasCreate"},
{"api_exec", dtype == nnfusion::element::f16 ? "cublasHgemmStridedBatched" : "cublasSgemmStridedBatched"},
{"api_exec",
dtype == nnfusion::element::f16 ? "cublasHgemmStridedBatched"
: "cublasSgemmStridedBatched"},
{"transA", transB ? "CUBLAS_OP_T" : "CUBLAS_OP_N"},
{"transB", transA ? "CUBLAS_OP_T" : "CUBLAS_OP_N"},
{"alpha", alpha},
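The `@...@` placeholders in the template are filled from the replacement map below it; for f16 the call instantiates to `cublasHgemmStridedBatched`. A sketch of what the rendered code amounts to, assuming row-major inputs of shape [batch, M, K] and [batch, K, N] (the `M`/`N`/`K`/`batch` variables are illustrative, not names from the template):

```cpp
#include <cublas_v2.h>
#include <cuda_fp16.h>

// Body of the generated launch function: row-major C[b] = A[b] * B[b].
// cuBLAS is column-major, so the operands are swapped and (N, M, K) is
// passed instead of (M, N, K), exactly as the template above does.
const half alpha = 1.0f, beta = 0.0f;
CUBLAS_SAFE_CALL(cublasHgemmStridedBatched(
    cublas_handle,
    CUBLAS_OP_N, CUBLAS_OP_N,
    N, M, K,
    &alpha,
    input1, /*lda=*/N, /*strideA=*/(long long)N * K, // logical B
    input0, /*ldb=*/K, /*strideB=*/(long long)K * M, // logical A
    &beta,
    output0, /*ldc=*/N, /*strideC=*/(long long)N * M,
    batch));
```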
10 changes: 7 additions & 3 deletions src/nnfusion/core/kernels/cuda_gpu/kernels/batch_norm.cpp
@@ -171,10 +171,14 @@ LanguageUnit_p cuda::BatchNormNCHW::emit_function_body()
/*
* todo: may have better solution, details in https://github.com/microsoft/nnfusion/issues/434
* */
if(dtype == nnfusion::element::f16){
lu << "output0[st + i] = __hadd(input1[c_id] , __hdiv(__hmul(input0[c_id], __hsub(input2[st + i], input3[c_id])), sqrtf(__hadd(__float2half("
if (dtype == nnfusion::element::f16)
{
lu << "output0[st + i] = __hadd(input1[c_id] , __hdiv(__hmul(input0[c_id], "
"__hsub(input2[st + i], input3[c_id])), sqrtf(__hadd(__float2half("
<< epsilon << "), input4[c_id]))));\n";
}else{
}
else
{
lu << "(input1[c_id] + (input0[c_id] * "
"(input2[st + i] - input3[c_id]) / sqrtf("
<< epsilon << " + input4[c_id])));\n";
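In the f16 branch the normalization is expressed entirely with half intrinsics: y = beta + gamma * (x - mean) / sqrt(eps + var). A rough stand-alone kernel equivalent of what the emitter strings together; the kernel name, parameter names, and launch geometry are assumptions, not the emitter's exact output:

```cpp
#include <cuda_fp16.h>

// Illustrative NCHW batch-norm kernel mirroring the emitted f16 expression.
__global__ void batch_norm_nchw_fp16(const half* gain,  // input0: scale (gamma)
                                     const half* bias,  // input1: offset (beta)
                                     const half* x,     // input2: activations
                                     const half* mean,  // input3: running mean
                                     const half* var,   // input4: running variance
                                     half* y,           // output0
                                     int C, int HW, float eps)
{
    const int c_id = blockIdx.x % C; // channel handled by this block
    const int st = blockIdx.x * HW;  // start of this (n, c) plane
    for (int i = threadIdx.x; i < HW; i += blockDim.x)
    {
        // y = beta + gamma * (x - mean) / sqrt(eps + var); sqrtf round-trips
        // through float here, mirroring the emitted expression (see the
        // issue #434 todo above).
        y[st + i] = __hadd(bias[c_id],
                           __hdiv(__hmul(gain[c_id], __hsub(x[st + i], mean[c_id])),
                                  sqrtf(__hadd(__float2half(eps), var[c_id]))));
    }
}
```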
280 changes: 142 additions & 138 deletions src/nnfusion/core/kernels/cuda_gpu/kernels/dot.cpp
@@ -207,161 +207,165 @@ LanguageUnit_p cuda::Dot::emit_function_body()
else if (dtype == element::f16)
{
// case 1: Scalar * Tensor
if (arg0_shape.empty() || arg1_shape.empty())
{
auto& second = (arg0_shape.empty() ? arg1_shape : arg0_shape);
size_t count = nnfusion::shape_size(second);
if (arg0_shape.empty() || arg1_shape.empty())
{
auto& second = (arg0_shape.empty() ? arg1_shape : arg0_shape);
size_t count = nnfusion::shape_size(second);

string firstarg = (arg0_shape.empty() ? "input1" : "input0");
string secondarg = (arg0_shape.empty() ? "input0" : "input1");
string firstarg = (arg0_shape.empty() ? "input1" : "input0");
string secondarg = (arg0_shape.empty() ? "input0" : "input1");

lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n";
lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n";

lu << "CUDA_SAFE_CALL(cudaMemcpy(output0, " << firstarg << ", " << count << ", cudaMemcpyDeviceToDevice));\n"; // copy `firstarg` to `output0`
lu << "CUBLAS_SAFE_CALL(nnfusionHalfScale(" << secondarg << ", output0, " << count << "));\n";
}
lu << "CUDA_SAFE_CALL(cudaMemcpy(output0, " << firstarg << ", " << count
<< ", cudaMemcpyDeviceToDevice));\n"; // copy `firstarg` to `output0`
lu << "CUBLAS_SAFE_CALL(nnfusionHalfScale(" << secondarg << ", output0, " << count
<< "));\n";
}
// // case 2: 1d Dot
else if ((arg0_shape.size() == arg1_shape.size()) && (arg0_shape.size() == reduction_axes))
{
for (int i = 0; i < arg0_shape.size(); i++)
{
if (arg0_shape[i] != arg1_shape[i])
{
std::vector<std::string> arg_vec{"arg0", "arg1"};
std::vector<nnfusion::Shape> shape_vec{arg0_shape, arg1_shape};

NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with "
<< nnfusion::join(shape_vec) << " respectively, at Node "
<< m_context->gnode->get_name()
<< ", do not match for dot op";
}
}

size_t count = nnfusion::shape_size(arg0_shape);
lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n";

lu << "CUBLAS_SAFE_CALL(cublasSdot(cublas_handle, " << count
<< ", static_cast<const float*>(input0), 1, static_cast<const float*>(input1), 1, "
"static_cast<float*>(output0)));\n";
}
// // matrix * vector
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) && (reduction_axes == 1))
{
lu << "const float alpha = 1.0;\n const float beta = 0;\n";
lu << "CUBLAS_SAFE_CALL(cublasSgemv(cublas_handle, ";
if (trans_A)
lu << "CUBLAS_OP_N, " << arg0_shape[0] << ", " << arg0_shape[1] << ", ";
else
lu << "CUBLAS_OP_T, " << arg0_shape[1] << ", " << arg0_shape[0] << ", ";
lu << " &alpha,"
<< " static_cast<const float*>(input0)," << arg0_shape[1] << ", "
<< " static_cast<const float*>(input1),"
<< " 1,"
<< " &beta,"
<< " static_cast<float*>(output0),"
<< " 1));\n";
}
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) && (reduction_axes == 1) &&
(trans_A || trans_B))
{
int m = trans_B ? arg1_shape[0] : arg1_shape[1];
int n = trans_A ? arg0_shape[1] : arg0_shape[0];
int k = trans_A ? arg0_shape[0] : arg0_shape[1];

lu << "const half alpha = 1.0;\nconst half beta = 0;\n";

lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle,"
<< (trans_B ? " CUBLAS_OP_T," : " CUBLAS_OP_N,")
<< (trans_A ? " CUBLAS_OP_T," : " CUBLAS_OP_N,") << " " << m << ","
<< " " << n << ","
<< " " << k << ","
<< " &alpha,"
<< " static_cast<const half*>(input1),"
<< " " << arg1_shape[1] << ","
<< " static_cast<const half*>(input0),"
<< " " << arg0_shape[1] << ","
<< " &beta,"
<< " static_cast<half*>(output0),"
<< " " << m << "));\n";
} else {
size_t axes_for_m_count = arg0_shape.size() - reduction_axes;
size_t axes_for_n_count = arg1_shape.size() - reduction_axes;
size_t axes_for_k_count = reduction_axes;
size_t m = 1;
size_t n = 1;
size_t k = 1;

// check if input and output size correct
// check and calculate k for arg0 and arg1
size_t arg0_k_idx = axes_for_m_count; // first axis in arg0 for k
size_t arg1_k_idx = 0; // first axis in arg1 for k

for (size_t i = 0; i < axes_for_k_count; i++)
else if ((arg0_shape.size() == arg1_shape.size()) && (arg0_shape.size() == reduction_axes))
{
k *= arg0_shape[arg0_k_idx];
if (arg0_shape[arg0_k_idx++] != arg1_shape[arg1_k_idx++])
for (int i = 0; i < arg0_shape.size(); i++)
{
std::vector<std::string> arg_vec{"arg0", "arg1"};
std::vector<nnfusion::Shape> shape_vec{arg0_shape, arg1_shape};
if (arg0_shape[i] != arg1_shape[i])
{
std::vector<std::string> arg_vec{"arg0", "arg1"};
std::vector<nnfusion::Shape> shape_vec{arg0_shape, arg1_shape};

NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with "
<< nnfusion::join(shape_vec) << " respectively, at Node "
<< m_context->gnode->get_name()
<< ", do not match for dot op";
NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with "
<< nnfusion::join(shape_vec) << " respectively, at Node "
<< m_context->gnode->get_name()
<< ", do not match for dot op";
}
}

size_t count = nnfusion::shape_size(arg0_shape);
lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n";

lu << "CUBLAS_SAFE_CALL(cublasSdot(cublas_handle, " << count
<< ", static_cast<const float*>(input0), 1, static_cast<const float*>(input1), 1, "
"static_cast<float*>(output0)));\n";
}
// check and calculate m for arg0 and out
size_t arg0_m_idx = 0; // first axis in arg0 for m
size_t out_m_idx = 0; // first axis in out for m
for (size_t i = 0; i < axes_for_m_count; i++)
// // matrix * vector
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) && (reduction_axes == 1))
{
m *= arg0_shape[arg0_m_idx];
if (arg0_shape[arg0_m_idx++] != out_shape[out_m_idx++])
{
std::vector<std::string> arg_vec{"arg0", "output"};
std::vector<nnfusion::Shape> shape_vec{arg0_shape, out_shape};
lu << "const float alpha = 1.0;\n const float beta = 0;\n";
lu << "CUBLAS_SAFE_CALL(cublasSgemv(cublas_handle, ";
if (trans_A)
lu << "CUBLAS_OP_N, " << arg0_shape[0] << ", " << arg0_shape[1] << ", ";
else
lu << "CUBLAS_OP_T, " << arg0_shape[1] << ", " << arg0_shape[0] << ", ";
lu << " &alpha,"
<< " static_cast<const float*>(input0)," << arg0_shape[1] << ", "
<< " static_cast<const float*>(input1),"
<< " 1,"
<< " &beta,"
<< " static_cast<float*>(output0),"
<< " 1));\n";
}
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) && (reduction_axes == 1) &&
(trans_A || trans_B))
{
int m = trans_B ? arg1_shape[0] : arg1_shape[1];
int n = trans_A ? arg0_shape[1] : arg0_shape[0];
int k = trans_A ? arg0_shape[0] : arg0_shape[1];

NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with "
<< nnfusion::join(shape_vec) << " respectively, at Node "
<< m_context->gnode->get_name()
<< ", do not match for dot op";
}
lu << "const half alpha = 1.0;\nconst half beta = 0;\n";

lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle,"
<< (trans_B ? " CUBLAS_OP_T," : " CUBLAS_OP_N,")
<< (trans_A ? " CUBLAS_OP_T," : " CUBLAS_OP_N,") << " " << m << ","
<< " " << n << ","
<< " " << k << ","
<< " &alpha,"
<< " static_cast<const half*>(input1),"
<< " " << arg1_shape[1] << ","
<< " static_cast<const half*>(input0),"
<< " " << arg0_shape[1] << ","
<< " &beta,"
<< " static_cast<half*>(output0),"
<< " " << m << "));\n";
}
// check and calculate n for arg1 and out
size_t arg1_n_idx = axes_for_k_count; // first axis in arg1 for n
size_t out_n_idx = axes_for_m_count; // first axis in out for n
for (size_t i = 0; i < axes_for_n_count; i++)
else
{
n *= arg1_shape[arg1_n_idx];
if (arg1_shape[arg1_n_idx++] != out_shape[out_n_idx++])
size_t axes_for_m_count = arg0_shape.size() - reduction_axes;
size_t axes_for_n_count = arg1_shape.size() - reduction_axes;
size_t axes_for_k_count = reduction_axes;
size_t m = 1;
size_t n = 1;
size_t k = 1;

// check if input and output size correct
// check and calculate k for arg0 and arg1
size_t arg0_k_idx = axes_for_m_count; // first axis in arg0 for k
size_t arg1_k_idx = 0; // first axis in arg1 for k

for (size_t i = 0; i < axes_for_k_count; i++)
{
std::vector<std::string> arg_vec{"arg1", "output"};
std::vector<nnfusion::Shape> shape_vec{arg1_shape, out_shape};
k *= arg0_shape[arg0_k_idx];
if (arg0_shape[arg0_k_idx++] != arg1_shape[arg1_k_idx++])
{
std::vector<std::string> arg_vec{"arg0", "arg1"};
std::vector<nnfusion::Shape> shape_vec{arg0_shape, arg1_shape};

NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with "
<< nnfusion::join(shape_vec) << " respectively, at Node "
<< m_context->gnode->get_name()
<< ", do not match for dot op";
NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with "
<< nnfusion::join(shape_vec) << " respectively, at Node "
<< m_context->gnode->get_name()
<< ", do not match for dot op";
}
}
}
// check and calculate m for arg0 and out
size_t arg0_m_idx = 0; // first axis in arg0 for m
size_t out_m_idx = 0; // first axis in out for m
for (size_t i = 0; i < axes_for_m_count; i++)
{
m *= arg0_shape[arg0_m_idx];
if (arg0_shape[arg0_m_idx++] != out_shape[out_m_idx++])
{
std::vector<std::string> arg_vec{"arg0", "output"};
std::vector<nnfusion::Shape> shape_vec{arg0_shape, out_shape};

NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with "
<< nnfusion::join(shape_vec) << " respectively, at Node "
<< m_context->gnode->get_name()
<< ", do not match for dot op";
}
}
// check and calculate n for arg1 and out
size_t arg1_n_idx = axes_for_k_count; // first axis in arg1 for n
size_t out_n_idx = axes_for_m_count; // first axis in out for n
for (size_t i = 0; i < axes_for_n_count; i++)
{
n *= arg1_shape[arg1_n_idx];
if (arg1_shape[arg1_n_idx++] != out_shape[out_n_idx++])
{
std::vector<std::string> arg_vec{"arg1", "output"};
std::vector<nnfusion::Shape> shape_vec{arg1_shape, out_shape};

lu << "const half alpha = 1.0f;\nconst half beta = 0.f;\n";

lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle,"
<< " CUBLAS_OP_N,"
<< " CUBLAS_OP_N,"
<< " " << n << ","
<< " " << m << ","
<< " " << k << ","
<< " &alpha,"
<< " static_cast<const half*>(input1),"
<< " " << n << ","
<< " static_cast<const half*>(input0),"
<< " " << k << ","
<< " &beta,"
<< " static_cast<half*>(output0),"
<< " " << n << "));\n";
}
NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with "
<< nnfusion::join(shape_vec) << " respectively, at Node "
<< m_context->gnode->get_name()
<< ", do not match for dot op";
}
}

lu << "const half alpha = 1.0f;\nconst half beta = 0.f;\n";

lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle,"
<< " CUBLAS_OP_N,"
<< " CUBLAS_OP_N,"
<< " " << n << ","
<< " " << m << ","
<< " " << k << ","
<< " &alpha,"
<< " static_cast<const half*>(input1),"
<< " " << n << ","
<< " static_cast<const half*>(input0),"
<< " " << k << ","
<< " &beta,"
<< " static_cast<half*>(output0),"
<< " " << n << "));\n";
}
}
else
{
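The last branch inside the f16 path collapses a general N-D dot into one GEMM: `m` accumulates the leading axes of arg0, `n` the trailing axes of arg1, and `k` the reduced axes, after the shape checks above. A condensed sketch of that mapping, with an illustrative free-function signature and shape vectors standing in for the kernel context:

```cpp
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cstddef>
#include <vector>

// Illustrative reduction of an N-D f16 dot to a single cublasHgemm call.
void dot_as_hgemm(cublasHandle_t cublas_handle,
                  const half* input0, const std::vector<size_t>& arg0_shape,
                  const half* input1, const std::vector<size_t>& arg1_shape,
                  half* output0, size_t reduction_axes)
{
    size_t m = 1, n = 1, k = 1;
    for (size_t i = 0; i < arg0_shape.size() - reduction_axes; ++i)
        m *= arg0_shape[i]; // leading axes of arg0
    for (size_t i = reduction_axes; i < arg1_shape.size(); ++i)
        n *= arg1_shape[i]; // trailing axes of arg1
    for (size_t i = 0; i < reduction_axes; ++i)
        k *= arg1_shape[i]; // shared (reduced) axes

    const half alpha = 1.0f, beta = 0.0f;
    // Row-major C[m,n] = A[m,k] * B[k,n] via column-major cuBLAS: swap the
    // operands and pass (n, m, k), as the emitted call above does.
    cublasHgemm(cublas_handle,
                CUBLAS_OP_N, CUBLAS_OP_N,
                (int)n, (int)m, (int)k,
                &alpha,
                input1, (int)n,
                input0, (int)k,
                &beta,
                output0, (int)n);
}
```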
6 changes: 4 additions & 2 deletions src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp
@@ -230,7 +230,10 @@ for (int tidx = thread_idx; tidx < width; tidx += block_size) {
val = reduceSum(val, thread_idx, block_size, shm);
if (thread_idx == 0) output0[block_idx] = val;
)",
{{"width", width}, {"block_size", expected_block_size}, {"warp_size", 32},{"dataType", dtype==nnfusion::element::f16? "half" : "float"}});
{{"width", width},
{"block_size", expected_block_size},
{"warp_size", 32},
{"dataType", dtype == nnfusion::element::f16 ? "half" : "float"}});

lu << code << "\n";
return _lu;
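The reduction body above is produced by `op::create_code_from_template`, which substitutes the map entries into the template string, here picking `half` or `float` for `dataType` from the element type (the placeholder syntax is the `@key@` form visible in batch_matmul.cpp). A minimal sketch of that kind of substitution; the `render_template` helper below is an illustrative re-implementation, not nnfusion's actual one:

```cpp
#include <iostream>
#include <map>
#include <string>

// Illustrative @key@ -> value substitution in the spirit of create_code_from_template.
std::string render_template(std::string tmpl,
                            const std::map<std::string, std::string>& vars)
{
    for (const auto& kv : vars)
    {
        const std::string key = "@" + kv.first + "@";
        size_t pos = tmpl.find(key);
        while (pos != std::string::npos)
        {
            tmpl.replace(pos, key.size(), kv.second);
            pos = tmpl.find(key, pos + kv.second.size());
        }
    }
    return tmpl;
}

int main()
{
    std::string code = render_template(
        "@dataType@ val = 0;\n"
        "val = reduceSum(val, thread_idx, @block_size@, shm);\n",
        {{"block_size", "512"}, {"dataType", "half"}});
    std::cout << code;
    // -> "half val = 0;" and "val = reduceSum(val, thread_idx, 512, shm);"
}
```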
@@ -582,7 +585,6 @@ if (thread_idx == 0) output0[block_idx] = val;
m_gridDim = dim3(1, 1, 1);
m_blockDim = dim3(block_size_x, 1, 1);
}

}
}
