Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix default_kernel1 entry #418

Open
wants to merge 24 commits into
base: osdi22_artifact
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions src/nnfusion/core/kernels/cuda_gpu/cuda_cudnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,16 +289,21 @@ LanguageUnit_p
{
dimensions[pos++] = static_cast<int>(shape[i]);
}
// lu << "CUDNN_SAFE_CALL(cudnnSetTensor4dDescriptorEx(" << desc << ", " << data_type << ", "
// << dimensions[0] << ", " << dimensions[1] << ", " << dimensions[2] << ", "
// << dimensions[3] << ", 1, 1, 1, 1));\n";
lu << "CUDNN_SAFE_CALL(cudnnSetTensor4dDescriptorEx(" << desc << ", " << data_type << ", "
<< dimensions[0] << ", " << dimensions[1] << ", " << dimensions[2] << ", "
<< dimensions[3] << ", 1, 1, 1, 1));\n";
<< "1, " << dimensions[1] << ", 1, 1, 1, 1, 1, 1));\n";
}
else if (shape.size() == 4)
{
// lu << "CUDNN_SAFE_CALL(cudnnSetTensor4dDescriptorEx(" << desc << ", " << data_type << ", "
// << static_cast<int>(shape[0]) << ", " << static_cast<int>(shape[1]) << ", "
// << static_cast<int>(shape[2]) << ", " << static_cast<int>(shape[3])
// << ", 1, 1, 1, 1));\n";

lu << "CUDNN_SAFE_CALL(cudnnSetTensor4dDescriptorEx(" << desc << ", " << data_type << ", "
<< static_cast<int>(shape[0]) << ", " << static_cast<int>(shape[1]) << ", "
<< static_cast<int>(shape[2]) << ", " << static_cast<int>(shape[3])
<< ", 1, 1, 1, 1));\n";
<< "1, " << static_cast<int>(shape[1]) << ",1, 1, 1, 1, 1, 1));\n";
}

return _lu;
Expand Down
29 changes: 13 additions & 16 deletions src/nnfusion/core/kernels/cuda_gpu/kernels/avg_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,6 @@ cuda::AvgPoolmD::AvgPoolmD(shared_ptr<KernelContext> ctx)

LanguageUnit_p cuda::AvgPoolmD::emit_function_body()
{
if (input_shape.size() != 4 && input_shape.size() != 5)
return nullptr;

LanguageUnit_p _lu(new LanguageUnit(get_function_name()));
auto& lu = *_lu;
auto rank = input_shape.size();
Expand All @@ -297,8 +294,8 @@ LanguageUnit_p cuda::AvgPoolmD::emit_function_body()
window_shape.insert(window_shape.begin(), 1);
padding_below.insert(padding_below.begin(), 0);
window_stride.insert(window_stride.begin(), 1);
_input_shape.insert(_input_shape.begin() + 1, 1);
_output_shape.insert(_output_shape.begin() + 1, 1);
_input_shape.insert(_input_shape.begin() + 2, 1);
_output_shape.insert(_output_shape.begin() + 2, 1);
rank = 4;
}

Expand Down Expand Up @@ -511,19 +508,19 @@ LanguageUnit_p cuda::AvgPoolmDGrad::emit_function_body()
auto _d_input_shape = d_input_shape;
auto _output_shape = output_shape;
auto _d_output_shape = d_output_shape;

NNFUSION_LOG(INFO) << "---------4";
if (rank == 3)
{
window_shape.insert(window_shape.begin(), 1);
padding_below.insert(padding_below.begin(), 0);
window_stride.insert(window_stride.begin(), 1);
_input_shape.insert(_input_shape.begin() + 1, 1);
_output_shape.insert(_output_shape.begin() + 1, 1);
_d_input_shape.insert(_d_input_shape.begin() + 1, 1);
_d_output_shape.insert(_d_output_shape.begin() + 1, 1);
_input_shape.insert(_input_shape.begin() + 2, 1);
_output_shape.insert(_output_shape.begin() + 2, 1);
_d_input_shape.insert(_d_input_shape.begin() + 2, 1);
_d_output_shape.insert(_d_output_shape.begin() + 2, 1);
rank = 4;
}

NNFUSION_LOG(INFO) << "---------5";
// y dy x dx
auto input_desc = cudnn_tensor_descriptor_from_shape(_input_shape, "input_desc", input_type);
auto d_input_desc =
Expand Down Expand Up @@ -605,14 +602,14 @@ LanguageUnit_p cuda::AvgPoolmDGrad::emit_function_body()
lu << "CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(output_desc));\n";
lu << "CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(d_output_desc));\n";
lu << "CUDNN_SAFE_CALL(cudnnDestroyPoolingDescriptor(desc));\n";

NNFUSION_LOG(INFO) << "---------6";
return _lu;
}

REGISTER_KERNEL_EMITTER(
"AvgPool", // op_name
Device(CUDA_GPU).TypeConstraint(element::f32).Tag("cuda_kernel").Priority(2), // attrs
cuda::AvgPool1D) // constructor
// REGISTER_KERNEL_EMITTER(
// "AvgPool", // op_name
// Device(CUDA_GPU).TypeConstraint(element::f32).Tag("cuda_kernel").Priority(2), // attrs
// cuda::AvgPool1D) // constructor

REGISTER_KERNEL_EMITTER(
"AvgPool", // op_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ cuda::DepthwiseConv2dNative::DepthwiseConv2dNative(shared_ptr<KernelContext> ctx

const Shape input_shape = Shape(ctx->inputs[0]->get_shape());
// [ filter_rows, filter_cols, in_depth, depth_multiplier]
const Shape filter_shape = Shape(ctx->inputs[1]->get_shape());
// const Shape filter_shape = Shape(ctx->inputs[1]->get_shape());
// ad_hoc
const Shape filter_shape_ori = Shape(ctx->inputs[1]->get_shape());
Shape filter_shape =
Shape{filter_shape_ori[2], filter_shape_ori[3], filter_shape_ori[0], filter_shape_ori[1]};
const Shape output_shape = Shape(ctx->outputs[0]->get_shape());

data_format = op->localOpConfig.getRoot()["data_format"];
Expand Down
26 changes: 21 additions & 5 deletions src/nnfusion/core/kernels/cuda_gpu/kernels/elementwise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,19 @@ namespace nnfusion

{
std::string tid =
"blockIdx.x * " + std::to_string(blocks) + " + threadIdx.x";
"blockIdx.x * " + std::to_string(blocks) + " * 2 + threadIdx.x";
std::string tid1 =
"blockIdx.x * " + std::to_string(blocks) + " * 2 + threadIdx.x + 128";
if (grids == 1)
{
tid = "threadIdx.x";
tid1 = "threadIdx.x + 128";
}
if (bound)
{
lu << "if (" << tid << " >= " << bound << ") return;";

lu << "if (" << tid1 << " >= " << bound << ") return;";
}
{
std::string invoke_func = op;
if (m_context->gnode->get_op_type() == "Convert")
Expand All @@ -89,6 +96,13 @@ namespace nnfusion
lu << "input" << i << "[" << tid << "], ";
}
lu << "input" << num_inputs - 1 << "[" << tid << "]);\n";

lu << "output0[" << tid1 << "] = " << invoke_func << "(";
for (size_t i = 0; i < num_inputs - 1; i++)
{
lu << "input" << i << "[" << tid1 << "], ";
}
lu << "input" << num_inputs - 1 << "[" << tid1 << "]);\n";
}
}
return lu_;
Expand Down Expand Up @@ -134,26 +148,28 @@ namespace nnfusion
{
uint32_t num_ele = static_cast<uint32_t>(
nnfusion::shape_size(m_context->outputs[0]->get_shape()));
for (int i = 512; i >= 64; i >>= 1)
for (int i = 128; i >= 64; i >>= 1)
{
if (num_ele % i == 0)
{
grids = num_ele / i, blocks = i, bound = 0;
grids = grids / 2;
return;
}
}
for (int i = 512; i >= 32; i--)
for (int i = 128; i >= 32; i--)
{
if (num_ele % i == 0)
{
grids = num_ele / i, blocks = i, bound = 0;
grids = grids / 2;
return;
}
}
if (num_ele < 32)
grids = 1, blocks = num_ele, bound = 0;
else
grids = (num_ele + 255) / 256, blocks = 256, bound = 1;
grids = (num_ele + 255) / 256, blocks = 128, bound = 1;
}

// shared_ptr<KernelContext> kernel_ctx;
Expand Down
41 changes: 35 additions & 6 deletions src/nnfusion/core/kernels/cuda_gpu/kernels/elementwise_fused.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,30 +72,35 @@ LanguageUnit_p ElementWiseFused::emit_function_body()
if (grids == 1)
{
lu << "int tid = threadIdx.x;\n";
lu << "int tid1 = threadIdx.x + 128;\n";
}
else
{
lu << "int tid = blockIdx.x * " << std::to_string(blocks) << " + threadIdx.x;\n";
lu << "int tid = blockIdx.x * " << std::to_string(blocks) << " * 2 + threadIdx.x;\n";
lu << "int tid1 = blockIdx.x * " << std::to_string(blocks) << "* 2 + threadIdx.x + 128;\n";
}
if (bound)
{
lu << "if (tid >= " << bound << ") return;\n";
lu << "if (tid1 >= " << bound << ") return;\n";
}

for (auto op_ctx : m_gnode->get_op_contexts())
{
auto& out_tw = op_ctx->outputs[0];
if (auto bc = std::dynamic_pointer_cast<nnfusion::op::Broadcast>(op_ctx->op))
{
std::string index = "";
std::string index = "", index1 = "";
if (bc->is_inner_broadcast())
{
index += "[tid / " + std::to_string(bc->get_inner_broadcast_size()) + "]";
index1 += "[tid1 / " + std::to_string(bc->get_inner_broadcast_size()) + "]";
}
else
{
NNFUSION_CHECK(bc->is_outer_broadcast());
index += "[tid % " + std::to_string(bc->get_outer_broadcast_size()) + "]";
index1 += "[tid1 % " + std::to_string(bc->get_outer_broadcast_size()) + "]";
}
local_tensors[out_tw->get_name()] = "temp" + std::to_string(temp_tensor_id++);
auto& in_tw = op_ctx->inputs[0];
Expand All @@ -104,6 +109,9 @@ LanguageUnit_p ElementWiseFused::emit_function_body()
lu << out_tw->get_element_type().c_type_string() << " "
<< local_tensors[out_tw->get_name()] << " = " << in_args[in_tw->get_name()] << index
<< ";\n";
lu << out_tw->get_element_type().c_type_string() << " "
<< local_tensors[out_tw->get_name()] << "_1 = " << in_args[in_tw->get_name()]
<< index1 << ";\n";
}
else if (auto rs = std::dynamic_pointer_cast<nnfusion::op::Reshape>(op_ctx->op))
{
Expand Down Expand Up @@ -150,23 +158,29 @@ LanguageUnit_p ElementWiseFused::emit_function_body()
invoke_func = op_kernel.first;
}
local_tensors[out_tw->get_name()] = "temp" + std::to_string(temp_tensor_id++);
std::vector<std::string> input_args;
std::vector<std::string> input_args, input_args1;
for (int i = 0; i < op_ctx->inputs.size(); i++)
{
auto& in_tw = op_ctx->inputs[i];
if (in_args.count(in_tw->get_name()) > 0)
{
input_args.push_back(in_args[in_tw->get_name()] + "[tid]");
input_args1.push_back(in_args[in_tw->get_name()] + "[tid1]");
}
else
{
NNFUSION_CHECK(local_tensors.count(in_tw->get_name()) > 0);
input_args.push_back(local_tensors[in_tw->get_name()]);
input_args1.push_back(local_tensors[in_tw->get_name()] + "_1");
}
}
lu << out_tw->get_element_type().c_type_string() << " "
<< local_tensors[out_tw->get_name()] << " = " << invoke_func << "("
<< join(input_args, ", ") << ");\n";

lu << out_tw->get_element_type().c_type_string() << " "
<< local_tensors[out_tw->get_name()] << "_1 = " << invoke_func << "("
<< join(input_args1, ", ") << ");\n";
}
}

Expand All @@ -183,6 +197,19 @@ LanguageUnit_p ElementWiseFused::emit_function_body()
<< lu.get_code() << " " << pair.first;
lu << in_args[pair.first] << "[tid];\n";
}

lu << pair.second << "[tid1] = ";
if (local_tensors.count(pair.first) > 0)
{
lu << local_tensors[pair.first] << "_1;\n";
}
else
{
NNFUSION_CHECK(in_args.count(pair.first) > 0) << m_context->gnode->get_name() << " "
<< lu.get_code() << " " << pair.first
<< "_1";
lu << in_args[pair.first] << "_1[tid1];\n";
}
}

return lu_;
Expand Down Expand Up @@ -261,26 +288,28 @@ void ElementWiseFused::compute_best_config(int& grids, int& blocks, int& bound)
{
uint32_t num_ele =
static_cast<uint32_t>(nnfusion::shape_size(m_context->outputs[0]->get_shape()));
for (int i = 512; i >= 64; i >>= 1)
for (int i = 128; i >= 64; i >>= 1)
{
if (num_ele % i == 0)
{
grids = num_ele / i, blocks = i, bound = 0;
grids = grids / 2;
return;
}
}
for (int i = 512; i >= 32; i--)
for (int i = 128; i >= 32; i--)
{
if (num_ele % i == 0)
{
grids = num_ele / i, blocks = i, bound = 0;
grids = grids / 2;
return;
}
}
if (num_ele < 32)
grids = 1, blocks = num_ele, bound = 0;
else
grids = (num_ele + 255) / 256, blocks = 256, bound = 1;
grids = (num_ele + 255) / 256, blocks = 128, bound = 1;
}

REGISTER_KERNEL_EMITTER(
Expand Down
31 changes: 29 additions & 2 deletions src/nnfusion/core/kernels/kernel_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "kernel_emitter.hpp"
#include "nnfusion/core/operators/generic_op/generic_op.hpp"
#include "nnfusion/engine/async_manager.hpp"

#include <string>
Expand Down Expand Up @@ -473,7 +474,7 @@ std::string nnfusion::kernels::KernelContext::generate_identifier()
str << avgpool->get_window_shape();
str << avgpool->get_window_movement_strides();
str << avgpool->get_padding_below();
str << avgpool->get_padding_above();
// str << avgpool->get_padding_above();
identifier += str.str();
}
else if (op_type == "MaxPool")
Expand All @@ -484,7 +485,7 @@ std::string nnfusion::kernels::KernelContext::generate_identifier()
str << maxpool->get_window_shape();
str << maxpool->get_window_movement_strides();
str << maxpool->get_padding_below();
str << maxpool->get_padding_above();
// str << maxpool->get_padding_above();
identifier += str.str();
}
else if (op_type == "Dot")
Expand All @@ -498,6 +499,32 @@ std::string nnfusion::kernels::KernelContext::generate_identifier()
// ///\todo: need to encode dot reduction_axes_count?
// identifier += str.str();
}
else if (op_type == "Sum")
{
auto op = std::static_pointer_cast<op::Sum>(ctx->gnode->get_op_ptr());
NNFUSION_CHECK_NOT_NULLPTR(op);
std::stringstream str;
str << op->get_reduction_axes();
identifier += str.str();
}
else if (op_type == "Broadcast")
{
auto op = std::static_pointer_cast<op::Broadcast>(ctx->gnode->get_op_ptr());
NNFUSION_CHECK_NOT_NULLPTR(op);
std::stringstream str;
str << op->get_broadcast_axes();
identifier += str.str();
}
else if (op_type == "DepthwiseConv2dNative")
{
auto op = std::static_pointer_cast<op::GenericOp>(ctx->gnode->get_op_ptr());
NNFUSION_CHECK_NOT_NULLPTR(op);
std::stringstream str;
str << op->localOpConfig.getRoot()["strides"];
str << op->localOpConfig.getRoot()["dilations"];
str << op->localOpConfig.getRoot()["padding_before"];
identifier += str.str();
}

return identifier;
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@ REGISTER_OP(DepthwiseConv2dNative)
const Shape& input_shape = gnode->get_input_shape(0);

// [ filter_rows, filter_cols, in_depth, depth_multiplier ]
const Shape& filter_shape = gnode->get_input_shape(1);
// const Shape& filter_shape = gnode->get_input_shape(1);

// ad_hoc: [ in_depth, depth_multiplier, filter_rows, filter_cols ]
const Shape& filter_shape_ts = gnode->get_input_shape(1);
nnfusion::Shape filter_shape{
filter_shape_ts[2], filter_shape_ts[3], filter_shape_ts[0], filter_shape_ts[1]};

std::string data_format = op->localOpConfig.getRoot()["data_format"];
bool is_nhwc = (data_format == "NHWC");
Expand Down
Loading