From 3a555659686845253b77033eff76c5a81fc2a3ea Mon Sep 17 00:00:00 2001 From: Lingxiao Ma Date: Wed, 20 Jul 2022 14:31:27 +0800 Subject: [PATCH 1/4] Support Conv3D in ONNX frontend and AntaresIR --- .../generic_op_define/Convolution.cpp | 174 +++++++++++++----- .../core/operators/util/validation_util.cpp | 35 ++-- src/nnfusion/frontend/onnx_import/op/conv.cpp | 14 +- 3 files changed, 160 insertions(+), 63 deletions(-) diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/Convolution.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/Convolution.cpp index 5c7e0e907..6d2b07dbd 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op_define/Convolution.cpp +++ b/src/nnfusion/core/operators/generic_op/generic_op_define/Convolution.cpp @@ -30,55 +30,139 @@ REGISTER_OP(Convolution) }) */ .translate_v2([](std::shared_ptr curr) -> std::string { - auto ir_template = - R"( @output0@@output0_layout@ +=! @input0@@input0_layout@@pad_cond@ * @input1@@input1_layout@ where HO in @height@, WO in @width@; )"; - auto _op = static_pointer_cast(curr->get_op_ptr()); NNFUSION_CHECK_NOT_NULLPTR(_op) << "Node type is not " << curr->get_op_ptr()->get_op_type(); - const auto& dilation_h = _op->get_window_dilation_strides()[0]; - const auto& dilation_w = _op->get_window_dilation_strides()[1]; - const auto& stride_h = _op->get_window_movement_strides()[0]; - const auto& stride_w = _op->get_window_movement_strides()[1]; - const auto& is_nchw = _op->get_data_format() == "NCHW"; - const auto& padding_below = _op->get_padding_below(); - const auto& padding_above = _op->get_padding_above(); - const auto& padding_h = _op->get_padding_below()[0]; - const auto& padding_w = _op->get_padding_below()[1]; - const auto& kernel_size_h = - is_nchw ? curr->get_input_shape(1)[2] : curr->get_input_shape(1)[0]; - const auto& kernel_size_w = - is_nchw ? 
curr->get_input_shape(1)[3] : curr->get_input_shape(1)[1]; - const auto& in_shape = curr->get_input_shape(0); - const auto& out_shape = curr->get_output_shape(0); - const std::string data_format = is_nchw ? "nchw" : "nhwc"; - NNFUSION_CHECK(dilation_h == 1) << "Not support other dilation yet."; - NNFUSION_CHECK(dilation_w == 1) << "Not support other dilation yet."; - NNFUSION_CHECK(padding_below == padding_above) - << "Asymetric padding is not supported by now."; - nnfusion::op::OpConfig::any config; - std::string HO = "-@pad_0@ + KH + HO * " + to_string(stride_h); - std::string WO = "-@pad_1@ + KW + WO * " + to_string(stride_w); - std::string shape_template = - is_nchw ? "[N, C, " + HO + ", " + WO + "]" : "[N, " + HO + ", " + WO + ", C]"; - config["input1_layout"] = is_nchw ? "[F, C, KH, KW]" : "[KH, KW, C, F]"; - config["output0_layout"] = is_nchw ? "[N, F, HO, WO]" : "[N, HO, WO, F]"; - config["height"] = is_nchw ? out_shape[2] : out_shape[1]; - config["width"] = is_nchw ? out_shape[3] : out_shape[2]; - config["pad_0"] = to_string(padding_h); - config["pad_1"] = to_string(padding_w); - config["input0_layout"] = op::create_code_from_template(shape_template, config); - std::string pad_cond; - if (padding_h || padding_w) + if (_op->get_data_format() == "NCHW" || _op->get_data_format() == "NHWC") // Conv2D { - config["in_height"] = is_nchw ? in_shape[2] : in_shape[1]; - config["in_width"] = is_nchw ? in_shape[3] : in_shape[2]; - auto pad_template = ".when([" + HO + " >= 0, " + HO + " < @in_height@, " + WO + - " >= 0, " + WO + - " < @in_width@], const(0.0).cast(@input0@@input0_layout@.dtype()))"; - pad_cond = op::create_code_from_template(pad_template, config); + auto ir_template = + R"( @output0@@output0_layout@ +=! 
@input0@@input0_layout@@pad_cond@ * @input1@@input1_layout@ where HO in @height@, WO in @width@; )"; + + const auto& dilation_h = _op->get_window_dilation_strides()[0]; + const auto& dilation_w = _op->get_window_dilation_strides()[1]; + const auto& stride_h = _op->get_window_movement_strides()[0]; + const auto& stride_w = _op->get_window_movement_strides()[1]; + const auto& is_nchw = _op->get_data_format() == "NCHW"; + const auto& padding_below = _op->get_padding_below(); + const auto& padding_above = _op->get_padding_above(); + const auto& padding_h = _op->get_padding_below()[0]; + const auto& padding_w = _op->get_padding_below()[1]; + const auto& kernel_size_h = + is_nchw ? curr->get_input_shape(1)[2] : curr->get_input_shape(1)[0]; + const auto& kernel_size_w = + is_nchw ? curr->get_input_shape(1)[3] : curr->get_input_shape(1)[1]; + const auto& in_shape = curr->get_input_shape(0); + const auto& out_shape = curr->get_output_shape(0); + const std::string data_format = is_nchw ? "nchw" : "nhwc"; + if (dilation_h != 1 || dilation_w != 1) + { + NNFUSION_LOG(NNFUSION_WARNING) << "Not support other dilation yet."; + return ""; + } + if (padding_below != padding_above) + { + NNFUSION_LOG(NNFUSION_WARNING) << "Asymetric padding is not supported by now."; + return ""; + } + // NNFUSION_CHECK(dilation_h == 1) << "Not support other dilation yet."; + // NNFUSION_CHECK(dilation_w == 1) << "Not support other dilation yet."; + // NNFUSION_CHECK(padding_below == padding_above) + // << "Asymetric padding is not supported by now."; + nnfusion::op::OpConfig::any config; + std::string HO = "-@pad_0@ + KH + HO * " + to_string(stride_h); + std::string WO = "-@pad_1@ + KW + WO * " + to_string(stride_w); + std::string shape_template = + is_nchw ? "[N, C, " + HO + ", " + WO + "]" : "[N, " + HO + ", " + WO + ", C]"; + config["input1_layout"] = is_nchw ? "[F, C, KH, KW]" : "[KH, KW, C, F]"; + config["output0_layout"] = is_nchw ? 
"[N, F, HO, WO]" : "[N, HO, WO, F]"; + config["height"] = is_nchw ? out_shape[2] : out_shape[1]; + config["width"] = is_nchw ? out_shape[3] : out_shape[2]; + config["pad_0"] = to_string(padding_h); + config["pad_1"] = to_string(padding_w); + config["input0_layout"] = op::create_code_from_template(shape_template, config); + + std::string pad_cond; + if (padding_h || padding_w) + { + config["in_height"] = is_nchw ? in_shape[2] : in_shape[1]; + config["in_width"] = is_nchw ? in_shape[3] : in_shape[2]; + auto pad_template = + ".when([" + HO + " >= 0, " + HO + " < @in_height@, " + WO + " >= 0, " + WO + + " < @in_width@], const(0.0).cast(@input0@@input0_layout@.dtype()))"; + pad_cond = op::create_code_from_template(pad_template, config); + } + config["pad_cond"] = pad_cond; + + return op::create_code_from_template(ir_template, config); + } + else if (_op->get_data_format() == "NCDHW") // Conv3D + { + auto ir_template = + R"( @output0@@output0_layout@ +=! @input0@@input0_layout@@pad_cond@ * @input1@@input1_layout@ where DO in @depth@, HO in @height@, WO in @width@; )"; + + const auto& dilation_d = _op->get_window_dilation_strides()[0]; + const auto& dilation_h = _op->get_window_dilation_strides()[1]; + const auto& dilation_w = _op->get_window_dilation_strides()[2]; + const auto& stride_d = _op->get_window_movement_strides()[0]; + const auto& stride_h = _op->get_window_movement_strides()[1]; + const auto& stride_w = _op->get_window_movement_strides()[2]; + const auto& padding_below = _op->get_padding_below(); + const auto& padding_above = _op->get_padding_above(); + const auto& padding_d = _op->get_padding_below()[0]; + const auto& padding_h = _op->get_padding_below()[1]; + const auto& padding_w = _op->get_padding_below()[2]; + const auto& kernel_size_d = curr->get_input_shape(1)[2]; + const auto& kernel_size_h = curr->get_input_shape(1)[3]; + const auto& kernel_size_w = curr->get_input_shape(1)[4]; + const auto& in_shape = curr->get_input_shape(0); + const auto& out_shape 
= curr->get_output_shape(0); + const std::string data_format = "NCDHW"; + if (dilation_d != 1 || dilation_h != 1 || dilation_w != 1) + { + NNFUSION_LOG(NNFUSION_WARNING) << "Not support other dilation yet."; + return ""; + } + if (padding_below != padding_above) + { + NNFUSION_LOG(NNFUSION_WARNING) << "Asymetric padding is not supported by now."; + return ""; + } + // NNFUSION_CHECK(dilation_d == 1) << "Not support other dilation yet."; + // NNFUSION_CHECK(dilation_h == 1) << "Not support other dilation yet."; + // NNFUSION_CHECK(dilation_w == 1) << "Not support other dilation yet."; + // NNFUSION_CHECK(padding_below == padding_above) + // << "Asymetric padding is not supported by now."; + nnfusion::op::OpConfig::any config; + std::string DO = "-@pad_0@ + KD + DO * " + to_string(stride_d); + std::string HO = "-@pad_1@ + KH + HO * " + to_string(stride_h); + std::string WO = "-@pad_2@ + KW + WO * " + to_string(stride_w); + std::string shape_template = "[N, C, " + DO + ", " + HO + ", " + WO + "]"; + config["input1_layout"] = "[F, C, KD, KH, KW]"; + config["output0_layout"] = "[N, F, DO, HO, WO]"; + config["depth"] = out_shape[2]; + config["height"] = out_shape[3]; + config["width"] = out_shape[4]; + config["pad_0"] = to_string(padding_d); + config["pad_1"] = to_string(padding_h); + config["pad_2"] = to_string(padding_w); + config["input0_layout"] = op::create_code_from_template(shape_template, config); + + std::string pad_cond; + if (padding_d || padding_h || padding_w) + { + config["in_depth"] = in_shape[2]; + config["in_height"] = in_shape[3]; + config["in_width"] = in_shape[4]; + auto pad_template = + ".when([" + DO + " >= 0, " + DO + " < @in_depth@, " + HO + " >= 0, " + HO + + " < @in_height@, " + WO + " >= 0, " + WO + + " < @in_width@], const(0.0).cast(@input0@@input0_layout@.dtype()))"; + pad_cond = op::create_code_from_template(pad_template, config); + } + config["pad_cond"] = pad_cond; + + return op::create_code_from_template(ir_template, config); } - 
config["pad_cond"] = pad_cond; - return op::create_code_from_template(ir_template, config); + return ""; }); diff --git a/src/nnfusion/core/operators/util/validation_util.cpp b/src/nnfusion/core/operators/util/validation_util.cpp index 1cbaa7456..758810682 100644 --- a/src/nnfusion/core/operators/util/validation_util.cpp +++ b/src/nnfusion/core/operators/util/validation_util.cpp @@ -176,26 +176,33 @@ std::tuple << "), padding above (" << data_padding_above << "), filter strides (" << filter_strides << "), and filter dilation (" << filter_dilation << ") do not match."; - OP_VALIDATION(op, data_format == "NCW" || data_format == "NCHW" || data_format == "NHWC") - << "data format must be Conv1D: NCW, Conv2D: NCHW or NHWC."; + OP_VALIDATION(op, + data_format == "NCW" || data_format == "NCHW" || data_format == "NHWC" || + data_format == "NCDHW") + << "data format must be Conv1D: NCW, Conv2D: NCHW or NHWC, Conv3D: NCDHW."; nnfusion::Dimension batch_size = (data_batch_shape.rank().is_static() ? data_batch_shape[0] : nnfusion::Dimension::dynamic()); nnfusion::Dimension data_channel_count = (data_batch_shape.rank().is_static() - ? (data_format == "NCW" || data_format == "NCHW") ? data_batch_shape[1] - : data_batch_shape[3] + ? (data_format == "NCW" || data_format == "NCHW" || data_format == "NCDHW") + ? data_batch_shape[1] + : data_batch_shape[3] : nnfusion::Dimension::dynamic()); nnfusion::PartialShape data_spatial_shape(nnfusion::PartialShape::dynamic(spatial_rank)); nnfusion::Dimension filter_output_channel_count = (filters_shape.rank().is_static() - ? (data_format == "NCW" || data_format == "NCHW") ? filters_shape[0] : filters_shape[3] + ? (data_format == "NCW" || data_format == "NCHW" || data_format == "NCDHW") + ? filters_shape[0] + : filters_shape[3] : nnfusion::Dimension::dynamic()); nnfusion::Dimension filter_input_channel_count = (filters_shape.rank().is_static() - ? (data_format == "NCW" || data_format == "NCHW") ? filters_shape[1] : filters_shape[2] + ? 
(data_format == "NCW" || data_format == "NCHW" || data_format == "NCDHW") + ? filters_shape[1] + : filters_shape[2] : nnfusion::Dimension::dynamic()); nnfusion::PartialShape filter_spatial_shape(nnfusion::PartialShape::dynamic(spatial_rank)); @@ -207,16 +214,18 @@ std::tuple { if (data_batch_shape.rank().is_static()) { - data_spatial_shape[i] = (data_format == "NCW" || data_format == "NCHW") - ? data_batch_shape[i + 2] - : data_batch_shape[i + 1]; + data_spatial_shape[i] = + (data_format == "NCW" || data_format == "NCHW" || data_format == "NCDHW") + ? data_batch_shape[i + 2] + : data_batch_shape[i + 1]; } if (filters_shape.rank().is_static()) { - filter_spatial_shape[i] = (data_format == "NCW" || data_format == "NCHW") - ? filters_shape[i + 2] - : filters_shape[i]; + filter_spatial_shape[i] = + (data_format == "NCW" || data_format == "NCHW" || data_format == "NCDHW") + ? filters_shape[i + 2] + : filters_shape[i]; } } @@ -253,7 +262,7 @@ std::tuple nnfusion::PartialShape batch_output_shape(nnfusion::PartialShape::dynamic(spatial_rank + 2)); - if (data_format == "NCW" || data_format == "NCHW") + if (data_format == "NCW" || data_format == "NCHW" || data_format == "NCDHW") { batch_output_shape[0] = batch_size; batch_output_shape[1] = filter_output_channel_count; diff --git a/src/nnfusion/frontend/onnx_import/op/conv.cpp b/src/nnfusion/frontend/onnx_import/op/conv.cpp index cb98dfd1d..a57341ebf 100644 --- a/src/nnfusion/frontend/onnx_import/op/conv.cpp +++ b/src/nnfusion/frontend/onnx_import/op/conv.cpp @@ -78,10 +78,10 @@ namespace nnfusion { conv_data_format = "NCHW"; } - // else if (data_shape.size() == 5) - // { - // conv_data_format = "NCDHW"; - // } + else if (data_shape.size() == 5) + { + conv_data_format = "NCDHW"; + } else { NNFUSION_CHECK_FAIL() << "Convolution with dimensions of " @@ -168,7 +168,7 @@ namespace nnfusion strides, dilations, padding_below, padding_above, conv_data_format); conv_node = m_graph->add_node_and_edge(conv_op, {data, filters}); } - 
else + else if (conv_data_format == "NCHW") { // split data and filters for group conv std::size_t n_data_channels{data_shape.at(1)}; @@ -264,6 +264,10 @@ namespace nnfusion convolution_nodes); } } + else + { + NNFUSION_CHECK_FAIL() << "Not support this Convolution yet."; + } // add bias if (input_indexes.size() == 3) From 68978c336dce755f5d09469912b698a586637fa1 Mon Sep 17 00:00:00 2001 From: Lingxiao Ma Date: Wed, 20 Jul 2022 14:43:25 +0800 Subject: [PATCH 2/4] format code style --- .../core/kernels/cuda_gpu/cuda_cudnn.cpp | 7 +- .../kernels/cuda_gpu/kernels/batch_matmul.cpp | 9 +- .../kernels/cuda_gpu/kernels/batch_norm.cpp | 10 +- .../core/kernels/cuda_gpu/kernels/dot.cpp | 280 +++++++++--------- .../core/kernels/cuda_gpu/kernels/reduce.hpp | 6 +- .../generic_op/generic_op_define/Slice.cpp | 10 +- .../engine/pass/codegen/cuda_codegen_pass.cpp | 2 +- .../engine/pass/extract_graph_signature.cpp | 6 +- src/nnfusion/frontend/onnx_import/op/pad.hpp | 32 +- .../frontend/onnx_import/util/util.hpp | 15 +- 10 files changed, 201 insertions(+), 176 deletions(-) diff --git a/src/nnfusion/core/kernels/cuda_gpu/cuda_cudnn.cpp b/src/nnfusion/core/kernels/cuda_gpu/cuda_cudnn.cpp index 040349588..5fca1dd00 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/cuda_cudnn.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/cuda_cudnn.cpp @@ -201,11 +201,12 @@ LanguageUnit_p cuda::get_cudnn_convolution_descriptor(const Shape& padding, << "window_dilation_strides_int, CUDNN_CROSS_CORRELATION, " << data_type << "));\n"; } - if(type == nnfusion::element::f16){ + if (type == nnfusion::element::f16) + { // half precision, use tensor core lu << "CUDNN_SAFE_CALL(cudnnSetConvolutionMathType(" << desc << ", " - << "CUDNN_TENSOR_OP_MATH" - << "));\n"; + << "CUDNN_TENSOR_OP_MATH" + << "));\n"; } return _lu; diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_matmul.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_matmul.cpp index dff602dc3..1741dbf40 100644 --- 
a/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_matmul.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_matmul.cpp @@ -103,8 +103,9 @@ namespace nnfusion @hCublas@, @transA@, @transB@, @m@, @n@, @k@, &alpha, input1, @lda@, @stride_a@, input0, @ldb@, @stride_b@, &beta, output0, @ldc@, @stride_c@, @batch@)); - )" : - R"( + )" + : + R"( static const float alpha = @alpha@F, beta = @beta@F; // if (!@hCublas@) // CUBLAS_SAFE_CALL(@api_create@(&@hCublas@)); @@ -116,7 +117,9 @@ namespace nnfusion { {"hCublas", "cublas_handle"}, {"api_create", "cublasCreate"}, - {"api_exec", dtype == nnfusion::element::f16 ? "cublasHgemmStridedBatched" : "cublasSgemmStridedBatched"}, + {"api_exec", + dtype == nnfusion::element::f16 ? "cublasHgemmStridedBatched" + : "cublasSgemmStridedBatched"}, {"transA", transB ? "CUBLAS_OP_T" : "CUBLAS_OP_N"}, {"transB", transA ? "CUBLAS_OP_T" : "CUBLAS_OP_N"}, {"alpha", alpha}, diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_norm.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_norm.cpp index 42404ac7b..a6c66dfdc 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_norm.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_norm.cpp @@ -171,10 +171,14 @@ LanguageUnit_p cuda::BatchNormNCHW::emit_function_body() /* * todo: may have better solution, details in https://github.com/microsoft/nnfusion/issues/434 * */ - if(dtype == nnfusion::element::f16){ - lu << "output0[st + i] = __hadd(input1[c_id] , __hdiv(__hmul(input0[c_id], __hsub(input2[st + i], input3[c_id])), sqrtf(__hadd(__float2half(" + if (dtype == nnfusion::element::f16) + { + lu << "output0[st + i] = __hadd(input1[c_id] , __hdiv(__hmul(input0[c_id], " + "__hsub(input2[st + i], input3[c_id])), sqrtf(__hadd(__float2half(" << epsilon << "), input4[c_id]))));\n"; - }else{ + } + else + { lu << "(input1[c_id] + (input0[c_id] * " "(input2[st + i] - input3[c_id]) / sqrtf(" << epsilon << " + input4[c_id])));\n"; diff --git 
a/src/nnfusion/core/kernels/cuda_gpu/kernels/dot.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/dot.cpp index 8e7b7a735..a78071f18 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/kernels/dot.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/dot.cpp @@ -207,161 +207,165 @@ LanguageUnit_p cuda::Dot::emit_function_body() else if (dtype == element::f16) { // case 1: Scalar * Tensor - if (arg0_shape.empty() || arg1_shape.empty()) - { - auto& second = (arg0_shape.empty() ? arg1_shape : arg0_shape); - size_t count = nnfusion::shape_size(second); + if (arg0_shape.empty() || arg1_shape.empty()) + { + auto& second = (arg0_shape.empty() ? arg1_shape : arg0_shape); + size_t count = nnfusion::shape_size(second); - string firstarg = (arg0_shape.empty() ? "input1" : "input0"); - string secondarg = (arg0_shape.empty() ? "input0" : "input1"); + string firstarg = (arg0_shape.empty() ? "input1" : "input0"); + string secondarg = (arg0_shape.empty() ? "input0" : "input1"); - lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n"; + lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n"; - lu << "CUDA_SAFE_CALL(cudaMemcpy(outupt0, " << firstarg << ", " << count << ", cudaMemcpyDeviceToDevice));\n"; // copy `firstarg` to `output0` - lu << "CUBLAS_SAFE_CALL(nnfusionHalfScale(" << secondarg << ", output0, " << count << "));\n"; - } + lu << "CUDA_SAFE_CALL(cudaMemcpy(outupt0, " << firstarg << ", " << count + << ", cudaMemcpyDeviceToDevice));\n"; // copy `firstarg` to `output0` + lu << "CUBLAS_SAFE_CALL(nnfusionHalfScale(" << secondarg << ", output0, " << count + << "));\n"; + } // // case 2: 1d Dot - else if ((arg0_shape.size() == arg1_shape.size()) && (arg0_shape.size() == reduction_axes)) - { - for (int i = 0; i < arg0_shape.size(); i++) - { - if (arg0_shape[i] != arg1_shape[i]) - { - std::vector arg_vec{"arg0", "arg1"}; - std::vector shape_vec{arg0_shape, arg1_shape}; - - NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " - << 
nnfusion::join(shape_vec) << " respectively, at Node " - << m_context->gnode->get_name() - << ", do not match for dot op"; - } - } - - size_t count = nnfusion::shape_size(arg0_shape); - lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n"; - - lu << "CUBLAS_SAFE_CALL(cublasSdot(cublas_handle, " << count - << ", static_cast(input0), 1, static_cast(input1), 1, " - "static_cast(output0)));\n"; - } - // // matrix * vector - else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) && (reduction_axes == 1)) - { - lu << "const float alpha = 1.0;\n const float beta = 0;\n"; - lu << "CUBLAS_SAFE_CALL(cublasSgemv(cublas_handle, "; - if (trans_A) - lu << "CUBLAS_OP_N, " << arg0_shape[0] << ", " << arg0_shape[1] << ", "; - else - lu << "CUBLAS_OP_T, " << arg0_shape[1] << ", " << arg0_shape[0] << ", "; - lu << " &alpha," - << " static_cast(input0)," << arg0_shape[1] << ", " - << " static_cast(input1)," - << " 1," - << " &beta," - << " static_cast(output0)," - << " 1));\n"; - } - else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) && (reduction_axes == 1) && - (trans_A || trans_B)) - { - int m = trans_B ? arg1_shape[0] : arg1_shape[1]; - int n = trans_A ? arg0_shape[1] : arg0_shape[0]; - int k = trans_A ? arg0_shape[0] : arg0_shape[1]; - - lu << "const half alpha = 1.0;\nconst half beta = 0;\n"; - - lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle," - << (trans_B ? " CUBLAS_OP_T," : " CUBLAS_OP_N,") - << (trans_A ? 
" CUBLAS_OP_T," : " CUBLAS_OP_N,") << " " << m << "," - << " " << n << "," - << " " << k << "," - << " &alpha," - << " static_cast(input1)," - << " " << arg1_shape[1] << "," - << " static_cast(input0)," - << " " << arg0_shape[1] << "," - << " &beta," - << " static_cast(output0)," - << " " << m << "));\n"; - } else { - size_t axes_for_m_count = arg0_shape.size() - reduction_axes; - size_t axes_for_n_count = arg1_shape.size() - reduction_axes; - size_t axes_for_k_count = reduction_axes; - size_t m = 1; - size_t n = 1; - size_t k = 1; - - // check if input and output size correct - // check and calculate k for arg0 and arg1 - size_t arg0_k_idx = axes_for_m_count; // first axe in arg0 for k - size_t arg1_k_idx = 0; // first axe in arg1 for k - - for (size_t i = 0; i < axes_for_k_count; i++) + else if ((arg0_shape.size() == arg1_shape.size()) && (arg0_shape.size() == reduction_axes)) { - k *= arg0_shape[arg0_k_idx]; - if (arg0_shape[arg0_k_idx++] != arg1_shape[arg1_k_idx++]) + for (int i = 0; i < arg0_shape.size(); i++) { - std::vector arg_vec{"arg0", "arg1"}; - std::vector shape_vec{arg0_shape, arg1_shape}; + if (arg0_shape[i] != arg1_shape[i]) + { + std::vector arg_vec{"arg0", "arg1"}; + std::vector shape_vec{arg0_shape, arg1_shape}; - NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " - << nnfusion::join(shape_vec) << " respectively, at Node " - << m_context->gnode->get_name() - << ", do not match for dot op"; + NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " + << nnfusion::join(shape_vec) << " respectively, at Node " + << m_context->gnode->get_name() + << ", do not match for dot op"; + } } + + size_t count = nnfusion::shape_size(arg0_shape); + lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n"; + + lu << "CUBLAS_SAFE_CALL(cublasSdot(cublas_handle, " << count + << ", static_cast(input0), 1, static_cast(input1), 1, " + "static_cast(output0)));\n"; } - // check and calculate m for arg0 and out - size_t arg0_m_idx = 0; // 
first axe in arg0 for m - size_t out_m_idx = 0; // first axe in out for m - for (size_t i = 0; i < axes_for_m_count; i++) + // // matrix * vector + else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) && (reduction_axes == 1)) { - m *= arg0_shape[arg0_m_idx]; - if (arg0_shape[arg0_m_idx++] != out_shape[out_m_idx++]) - { - std::vector arg_vec{"arg0", "output"}; - std::vector shape_vec{arg0_shape, out_shape}; + lu << "const float alpha = 1.0;\n const float beta = 0;\n"; + lu << "CUBLAS_SAFE_CALL(cublasSgemv(cublas_handle, "; + if (trans_A) + lu << "CUBLAS_OP_N, " << arg0_shape[0] << ", " << arg0_shape[1] << ", "; + else + lu << "CUBLAS_OP_T, " << arg0_shape[1] << ", " << arg0_shape[0] << ", "; + lu << " &alpha," + << " static_cast(input0)," << arg0_shape[1] << ", " + << " static_cast(input1)," + << " 1," + << " &beta," + << " static_cast(output0)," + << " 1));\n"; + } + else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) && (reduction_axes == 1) && + (trans_A || trans_B)) + { + int m = trans_B ? arg1_shape[0] : arg1_shape[1]; + int n = trans_A ? arg0_shape[1] : arg0_shape[0]; + int k = trans_A ? arg0_shape[0] : arg0_shape[1]; - NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " - << nnfusion::join(shape_vec) << " respectively, at Node " - << m_context->gnode->get_name() - << ", do not match for dot op"; - } + lu << "const half alpha = 1.0;\nconst half beta = 0;\n"; + + lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle," + << (trans_B ? " CUBLAS_OP_T," : " CUBLAS_OP_N,") + << (trans_A ? 
" CUBLAS_OP_T," : " CUBLAS_OP_N,") << " " << m << "," + << " " << n << "," + << " " << k << "," + << " &alpha," + << " static_cast(input1)," + << " " << arg1_shape[1] << "," + << " static_cast(input0)," + << " " << arg0_shape[1] << "," + << " &beta," + << " static_cast(output0)," + << " " << m << "));\n"; } - // check and calculate n for arg1 and out - size_t arg1_n_idx = axes_for_k_count; // first axe in arg1 for n - size_t out_n_idx = axes_for_m_count; // first axe in arg1 for n - for (size_t i = 0; i < axes_for_n_count; i++) + else { - n *= arg1_shape[arg1_n_idx]; - if (arg1_shape[arg1_n_idx++] != out_shape[out_n_idx++]) + size_t axes_for_m_count = arg0_shape.size() - reduction_axes; + size_t axes_for_n_count = arg1_shape.size() - reduction_axes; + size_t axes_for_k_count = reduction_axes; + size_t m = 1; + size_t n = 1; + size_t k = 1; + + // check if input and output size correct + // check and calculate k for arg0 and arg1 + size_t arg0_k_idx = axes_for_m_count; // first axe in arg0 for k + size_t arg1_k_idx = 0; // first axe in arg1 for k + + for (size_t i = 0; i < axes_for_k_count; i++) { - std::vector arg_vec{"arg1", "output"}; - std::vector shape_vec{arg1_shape, out_shape}; + k *= arg0_shape[arg0_k_idx]; + if (arg0_shape[arg0_k_idx++] != arg1_shape[arg1_k_idx++]) + { + std::vector arg_vec{"arg0", "arg1"}; + std::vector shape_vec{arg0_shape, arg1_shape}; - NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " - << nnfusion::join(shape_vec) << " respectively, at Node " - << m_context->gnode->get_name() - << ", do not match for dot op"; + NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " + << nnfusion::join(shape_vec) << " respectively, at Node " + << m_context->gnode->get_name() + << ", do not match for dot op"; + } } - } + // check and calculate m for arg0 and out + size_t arg0_m_idx = 0; // first axe in arg0 for m + size_t out_m_idx = 0; // first axe in out for m + for (size_t i = 0; i < axes_for_m_count; i++) + { + m *= 
arg0_shape[arg0_m_idx]; + if (arg0_shape[arg0_m_idx++] != out_shape[out_m_idx++]) + { + std::vector arg_vec{"arg0", "output"}; + std::vector shape_vec{arg0_shape, out_shape}; + + NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " + << nnfusion::join(shape_vec) << " respectively, at Node " + << m_context->gnode->get_name() + << ", do not match for dot op"; + } + } + // check and calculate n for arg1 and out + size_t arg1_n_idx = axes_for_k_count; // first axe in arg1 for n + size_t out_n_idx = axes_for_m_count; // first axe in arg1 for n + for (size_t i = 0; i < axes_for_n_count; i++) + { + n *= arg1_shape[arg1_n_idx]; + if (arg1_shape[arg1_n_idx++] != out_shape[out_n_idx++]) + { + std::vector arg_vec{"arg1", "output"}; + std::vector shape_vec{arg1_shape, out_shape}; - lu << "const half alpha = 1.0f;\nconst half beta = 0.f;\n"; - - lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle," - << " CUBLAS_OP_N," - << " CUBLAS_OP_N," - << " " << n << "," - << " " << m << "," - << " " << k << "," - << " &alpha," - << " static_cast(input1)," - << " " << n << "," - << " static_cast(input0)," - << " " << k << "," - << " &beta," - << " static_cast(output0)," - << " " << n << "));\n"; - } + NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " + << nnfusion::join(shape_vec) << " respectively, at Node " + << m_context->gnode->get_name() + << ", do not match for dot op"; + } + } + + lu << "const half alpha = 1.0f;\nconst half beta = 0.f;\n"; + + lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle," + << " CUBLAS_OP_N," + << " CUBLAS_OP_N," + << " " << n << "," + << " " << m << "," + << " " << k << "," + << " &alpha," + << " static_cast(input1)," + << " " << n << "," + << " static_cast(input0)," + << " " << k << "," + << " &beta," + << " static_cast(output0)," + << " " << n << "));\n"; + } } else { diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp index 5c9146afb..4c47ba346 100644 --- 
a/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp @@ -230,7 +230,10 @@ for (int tidx = thread_idx; tidx < width; tidx += block_size) { val = reduceSum(val, thread_idx, block_size, shm); if (thread_idx == 0) output0[block_idx] = val; )", - {{"width", width}, {"block_size", expected_block_size}, {"warp_size", 32},{"dataType", dtype==nnfusion::element::f16? "half" : "float"}}); + {{"width", width}, + {"block_size", expected_block_size}, + {"warp_size", 32}, + {"dataType", dtype == nnfusion::element::f16 ? "half" : "float"}}); lu << code << "\n"; return _lu; @@ -582,7 +585,6 @@ if (thread_idx == 0) output0[block_idx] = val; m_gridDim = dim3(1, 1, 1); m_blockDim = dim3(block_size_x, 1, 1); } - } } diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/Slice.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/Slice.cpp index 75f02a837..4bfc4fa03 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op_define/Slice.cpp +++ b/src/nnfusion/core/operators/generic_op/generic_op_define/Slice.cpp @@ -36,10 +36,12 @@ REGISTER_OP(Slice) auto step = steps[d]; auto start = starts[d]; auto end = ends[d]; - auto range = (u_int64_t)ceil((double)(end-start)/(double)step); - input_layout.push_back((step == 1? output_layout[d] : output_layout[d] + " * " + to_string(step)) + " + " + to_string(start)); - slice_dims += (slice_dims.empty() ? "" : " , ") + output_layout[d] + - " in " + to_string(range); + auto range = (u_int64_t)ceil((double)(end - start) / (double)step); + input_layout.push_back( + (step == 1 ? output_layout[d] : output_layout[d] + " * " + to_string(step)) + + " + " + to_string(start)); + slice_dims += + (slice_dims.empty() ? 
"" : " , ") + output_layout[d] + " in " + to_string(range); } auto expression_code = op::create_code_from_template( diff --git a/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp b/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp index 2537142bf..6f520d104 100644 --- a/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp +++ b/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp @@ -1371,7 +1371,7 @@ cmake_minimum_required(VERSION 3.5) SET(SRC "nnfusion_rt.cu" CACHE STRING "codegen source file") SET(TARGET_NAME "nnfusion_naive_rt" CACHE STRING "codegen target name") -SET(CUDA_ARCH "-gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75" CACHE STRING "target architecture") +SET(CUDA_ARCH "-gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86" CACHE STRING "target architecture") if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) diff --git a/src/nnfusion/engine/pass/extract_graph_signature.cpp b/src/nnfusion/engine/pass/extract_graph_signature.cpp index ce0e200bc..f537121f8 100644 --- a/src/nnfusion/engine/pass/extract_graph_signature.cpp +++ b/src/nnfusion/engine/pass/extract_graph_signature.cpp @@ -142,7 +142,8 @@ bool ExtractGraphSignature::extract_args(std::shared_ptr ctx const element::Type& et = tv->get_element_type(); string type; - if(!element::Type::nnfusion_element_type_to_dtype_string(tv->get_element_type(), type)){ + if (!element::Type::nnfusion_element_type_to_dtype_string(tv->get_element_type(), type)) + { NNFUSION_LOG(ERROR) << "Get element type failed"; return false; } @@ -188,7 +189,8 @@ bool ExtractGraphSignature::extract_output(std::shared_ptr c tu->out.push_back(tv); string type; - if(!element::Type::nnfusion_element_type_to_dtype_string(tv->get_element_type(), type)){ + if 
(!element::Type::nnfusion_element_type_to_dtype_string(tv->get_element_type(), type)) + { NNFUSION_LOG(ERROR) << "Get element type failed"; return false; } diff --git a/src/nnfusion/frontend/onnx_import/op/pad.hpp b/src/nnfusion/frontend/onnx_import/op/pad.hpp index a89a09133..7c65adb38 100644 --- a/src/nnfusion/frontend/onnx_import/op/pad.hpp +++ b/src/nnfusion/frontend/onnx_import/op/pad.hpp @@ -22,16 +22,17 @@ namespace nnfusion /* * since opset 11, 'pads' and 'value' have been moved from attributes to inputs * */ - if (node_proto.attribute_size() == 1){ - + if (node_proto.attribute_size() == 1) + { auto input_gnode = GetInputNode(all_ng_nodes, node_proto, 0); auto padding_gnode = GetInputNode(all_ng_nodes, node_proto, 1); std::vector paddings; bool status = GetValueFromNGraphOp(padding_gnode, &paddings); NNFUSION_CHECK(status); - NNFUSION_CHECK(paddings.size() % 2 == 0) - << "Constant node for paddings does not have an even number of elements"; + NNFUSION_CHECK(paddings.size() % 2 == 0) << "Constant node for paddings " + "does not have an even number " + "of elements"; nnfusion::Shape padding_below(paddings.size() / 2); nnfusion::Shape padding_above(paddings.size() / 2); @@ -48,10 +49,11 @@ namespace nnfusion std::make_shared(input_gnode->get_element_type(), nnfusion::Shape{}, std::vector{"0"}); - auto pad_val_gnode = m_graph->add_node_and_edge(pad_val_op, GNodeVector({})); + auto pad_val_gnode = + m_graph->add_node_and_edge(pad_val_op, GNodeVector({})); - auto pad_op = - std::make_shared(padding_below, padding_above, padding_interior); + auto pad_op = std::make_shared( + padding_below, padding_above, padding_interior); pad_op->set_name(node_proto.output(0)); auto pad_gnode = @@ -59,7 +61,9 @@ namespace nnfusion NamedNodeVector ret{{node_proto.output(0), pad_gnode}}; return ret; - }else{ + } + else + { cout << "meet pad op" << endl; /* for pad op, 0: mode, 1: pads, 2: constant * we can use attr.name() to get the name of the attr @@ -70,7 +74,8 @@ namespace 
nnfusion auto input_gnode = GetInputNode(all_ng_nodes, node_proto, 0); const onnx::AttributeProto& modeAttr = node_proto.attribute(0); cout << modeAttr.name() << endl; - if(modeAttr.s() != "constant") NNFUSION_CHECK_FAIL() << "unsupported padding type: " << modeAttr.s(); + if (modeAttr.s() != "constant") + NNFUSION_CHECK_FAIL() << "unsupported padding type: " << modeAttr.s(); const onnx::AttributeProto& padAttr = node_proto.attribute(1); cout << padAttr.name() << endl; for (int i = 0; i < 8; ++i) @@ -85,7 +90,8 @@ namespace nnfusion std::make_shared(input_gnode->get_element_type(), nnfusion::Shape{}, std::vector{"0"}); - auto pad_val_gnode = m_graph->add_node_and_edge(pad_val_op, GNodeVector({})); + auto pad_val_gnode = + m_graph->add_node_and_edge(pad_val_op, GNodeVector({})); nnfusion::Shape padding_below(4); nnfusion::Shape padding_above(4); nnfusion::Shape padding_interior(4); @@ -95,11 +101,11 @@ namespace nnfusion for (int i = 0; i < 4; ++i) { padding_below[i] = padAttr.ints(i); - padding_above[i] = padAttr.ints(i+4); + padding_above[i] = padAttr.ints(i + 4); } - auto pad_op = - std::make_shared(padding_below, padding_above, padding_interior); + auto pad_op = std::make_shared( + padding_below, padding_above, padding_interior); pad_op->set_name(node_proto.output(0)); auto pad_gnode = diff --git a/src/nnfusion/frontend/onnx_import/util/util.hpp b/src/nnfusion/frontend/onnx_import/util/util.hpp index 21c82e078..fb833c146 100644 --- a/src/nnfusion/frontend/onnx_import/util/util.hpp +++ b/src/nnfusion/frontend/onnx_import/util/util.hpp @@ -143,13 +143,15 @@ namespace nnfusion } template <> - inline std::vector get_data(const onnx::TensorProto& tensor){ - + inline std::vector get_data(const onnx::TensorProto& tensor) + { if (tensor.has_raw_data()) { - return __get_raw_data(tensor.raw_data()); - }else{ - NNFUSION_LOG(NNFUSION_WARNING) << "Have no raw data" << endl ; + return __get_raw_data(tensor.raw_data()); + } + else + { + NNFUSION_LOG(NNFUSION_WARNING) << "Have 
no raw data" << endl; } if (tensor.data_type() == onnx::TensorProto_DataType_FLOAT16) @@ -191,9 +193,8 @@ namespace nnfusion NNFUSION_CHECK_FAIL() << "invalid data type: " << onnx::TensorProto_DataType_Name( - static_cast(tensor.data_type())); + static_cast(tensor.data_type())); return std::vector(); - } template <> From bf746059a3152ff89b6efad4d3e438459885ebb6 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Thu, 21 Jul 2022 19:48:39 +0800 Subject: [PATCH 3/4] [Fix/Feat] Correct the fp16 inference of resnet50.onnx (#433) Co-authored-by: Wenxiang Hu <8460860+wenxcs@users.noreply.github.com> --- .../frontend/onnx_import/util/util.hpp | 45 ------------------- 1 file changed, 45 deletions(-) diff --git a/src/nnfusion/frontend/onnx_import/util/util.hpp b/src/nnfusion/frontend/onnx_import/util/util.hpp index 21c82e078..23a136f01 100644 --- a/src/nnfusion/frontend/onnx_import/util/util.hpp +++ b/src/nnfusion/frontend/onnx_import/util/util.hpp @@ -148,52 +148,7 @@ namespace nnfusion if (tensor.has_raw_data()) { return __get_raw_data(tensor.raw_data()); - }else{ - NNFUSION_LOG(NNFUSION_WARNING) << "Have no raw data" << endl ; } - - if (tensor.data_type() == onnx::TensorProto_DataType_FLOAT16) - { - nnfusion::Shape shape{std::begin(tensor.dims()), std::end(tensor.dims())}; - size_t num_element = shape_size(shape); - std::vector raw_data = __get_data(tensor.int32_data()); - std::vector ret((num_element + 1) / 2); - uint32_t* src_p = (uint32_t*)raw_data.data(); - uint16_t* dst_p = (uint16_t*)ret.data(); - for (size_t i = 0; i < num_element; i++) - { - NNFUSION_CHECK((src_p[i] & 0xFFFF0000) == 0); - dst_p[i] = src_p[i] & 0x0000FFFF; - } - if (num_element % 2 == 1) - { - dst_p[num_element] = 0; - } - - return ret; - } - if (tensor.data_type() == onnx::TensorProto_DataType_FLOAT) - { - return __get_data(tensor.float_data()); - } - if (tensor.data_type() == onnx::TensorProto_DataType_INT32) - { - return __get_data(tensor.int32_data()); 
- } - if (tensor.data_type() == onnx::TensorProto_DataType_INT64) - { - return __get_data(tensor.int64_data()); - } - if (tensor.data_type() == onnx::TensorProto_DataType_UINT64) - { - return __get_data(tensor.uint64_data()); - } - NNFUSION_CHECK_FAIL() - << "invalid data type: " - << onnx::TensorProto_DataType_Name( - static_cast(tensor.data_type())); - return std::vector(); - } template <> From 324f7a924a9cda4cc8b073374033ca467880eb27 Mon Sep 17 00:00:00 2001 From: Jilong Xue Date: Tue, 26 Jul 2022 11:27:03 +0800 Subject: [PATCH 4/4] Update README.md --- artifacts/README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/artifacts/README.md b/artifacts/README.md index a629ddc3e..d5e05324a 100644 --- a/artifacts/README.md +++ b/artifacts/README.md @@ -1,5 +1,8 @@ # OSDI'20 Artifacts Evaluation -OSDI'20 Artifact Evaluation of paper #292, titled "[Rammer: Enabling Holistic Deep Learning Compiler Optimizations with rTasks](https://www.usenix.org/conference/osdi20/presentation/ma)". +- OSDI'20 Artifact Evaluation of paper #292, titled "[Rammer: Enabling Holistic Deep Learning Compiler Optimizations with rTasks](https://www.usenix.org/conference/osdi20/presentation/ma)". +**Please refer to the [osdi20_artifact branch](https://github.com/microsoft/nnfusion/tree/osdi20_artifact/artifacts)** -**Please refer to the [osdi20_artifact branch](https://github.com/microsoft/nnfusion/tree/osdi20_artifact/artifacts)** \ No newline at end of file + +- OSDI'22 Artifact Evaluation of paper #158, titled "[Roller: Fast and Efficient Tensor Compilation for Deep Learning](https://www.usenix.org/conference/osdi22/presentation/zhu)". +**Please refer to the [osdi22_artifact branch](https://github.com/microsoft/nnfusion/tree/osdi22_artifact/artifacts)**