diff --git a/src/nnfusion/common/common.hpp b/src/nnfusion/common/common.hpp index 8c8ec2b3d..2d9d3dfe3 100644 --- a/src/nnfusion/common/common.hpp +++ b/src/nnfusion/common/common.hpp @@ -67,6 +67,7 @@ #include "nnfusion/core/operators/op_define/maximum.hpp" #include "nnfusion/core/operators/op_define/min.hpp" #include "nnfusion/core/operators/op_define/minimum.hpp" +#include "nnfusion/core/operators/op_define/mod.hpp" #include "nnfusion/core/operators/op_define/multiply.hpp" #include "nnfusion/core/operators/op_define/negative.hpp" #include "nnfusion/core/operators/op_define/not.hpp" diff --git a/src/nnfusion/core/kernels/cuda_gpu/cuda_emitter.hpp b/src/nnfusion/core/kernels/cuda_gpu/cuda_emitter.hpp index fada9792d..d38e234f2 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/cuda_emitter.hpp +++ b/src/nnfusion/core/kernels/cuda_gpu/cuda_emitter.hpp @@ -11,6 +11,7 @@ #include "nnfusion/engine/async_manager.hpp" DECLARE_string(fantares_codegen_server); +DECLARE_string(ftuning_list); namespace nnfusion { @@ -206,9 +207,14 @@ namespace nnfusion , m_antares_ke_imp(new AntaresKEImp) { GENERIC_OP_LOGGING(); + parse_tuning_list(); if (!FLAGS_fantares_codegen_server.empty()) { // NNFUSION_LOG(INFO) << "Translate for " << ctx->gnode->get_op_type(); + if (TuningList.find(ctx->gnode->get_op_type()) == TuningList.end()) + { + return; + } ir = nnfusion::op::get_translation(ctx->gnode); #if 0 @@ -287,6 +293,7 @@ namespace nnfusion << ctx->gnode->get_op_type(); log_cache.insert(ctx->gnode->get_op_type()); } + return; } kernel_info = @@ -316,6 +323,19 @@ namespace nnfusion std::string ir; bool is_memcpy = false; + bool parse_tuning_list() + { + auto tuninglist_str = FLAGS_ftuning_list; + stringstream ss(tuninglist_str); + while (ss.good()) + { + string substr; + getline(ss, substr, ','); + TuningList.insert(substr); + } + NNFUSION_LOG(INFO) << "Kernel Tuning List: " << join(TuningList, ", "); + } + protected: // map tensor names and allocate tmp tensor void 
process_antares_kernel_info(); @@ -323,6 +343,7 @@ namespace nnfusion std::vector kernel_info; std::unordered_map tensor_name_map; // antares tensor name : kernel tensor name + std::unordered_set TuningList; }; class CacheCudaEmitter : public CudaEmitter diff --git a/src/nnfusion/core/kernels/cuda_gpu/cuda_kernelops.hpp b/src/nnfusion/core/kernels/cuda_gpu/cuda_kernelops.hpp index df40b2c60..6ff83975a 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/cuda_kernelops.hpp +++ b/src/nnfusion/core/kernels/cuda_gpu/cuda_kernelops.hpp @@ -53,6 +53,7 @@ namespace nnfusion class GreaterEq; class Less; class LessEq; + class Mod; class Not; class Relu; class ReluBackprop; @@ -275,6 +276,13 @@ namespace nnfusion static constexpr const char* math_kernel = "x1 != 0 ? fdividef(x0, x1) : 0"; }; + template <> + struct CudaOpMap + { + static constexpr const char* op = "fmod"; + static constexpr const char* math_kernel = nullptr; + }; + template <> struct CudaOpMap { diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/elementwise.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/elementwise.cpp index 1b0e7d160..09df62403 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/kernels/elementwise.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/elementwise.cpp @@ -38,6 +38,7 @@ REGISTER_EW_KERNEL(PowerBackwardExponent) REGISTER_EW_KERNEL(Subtract) REGISTER_EW_KERNEL(Divide) REGISTER_EW_KERNEL(DivNoNan) +REGISTER_EW_KERNEL(Mod) REGISTER_EW_KERNEL(Sign) REGISTER_EW_KERNEL(Convert) REGISTER_EW_KERNEL(Equal) diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/if.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/if.cpp new file mode 100644 index 000000000..9109ca43a --- /dev/null +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/if.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#include "if.hpp" +#include "../cuda_cudnn.hpp" +#include "convolution.hpp" + +using namespace nnfusion; +using namespace nnfusion::kernels; + +cuda::If::If(shared_ptr ctx) + : KernelEmitter(ctx) +{ + std::stringstream tag; + tag << "_IfOP"; + custom_tag = tag.str(); +} + +LanguageUnit_p cuda::If::emit_function_body() +{ + LanguageUnit_p _lu(new LanguageUnit(get_function_name())); + auto& lu = *_lu; + + // function signature: + // extern "C" __global__ void kernel(m_context->dtypes[0]* input0, m_context->dtypes[0]* input1, m_context->dtypes[2]* output0) + lu << "// TODO\n"; + return _lu; +} + +LanguageUnit_p cuda::If::emit_dependency() +{ + LanguageUnit_p _lu(new LanguageUnit(get_function_name() + "_dep")); + _lu->require(header::cuda); + return _lu; +} + +REGISTER_KERNEL_EMITTER("If", // op_name + Device(CUDA_GPU).TypeConstraint(element::f32).Priority(2), // attrs + cuda::If) \ No newline at end of file diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/if.hpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/if.hpp new file mode 100644 index 000000000..f7f957705 --- /dev/null +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/if.hpp @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#pragma once +#include "../cuda_emitter.hpp" +#include "../cuda_langunit.hpp" + +namespace nnfusion +{ + namespace kernels + { + namespace cuda + { + class If : public KernelEmitter + { + public: + If(shared_ptr ctx); + + LanguageUnit_p emit_function_body() override; + LanguageUnit_p emit_dependency() override; + // LanguageUnit_p emit_function_signature() override; + }; + } // namespace cuda + } // namespace kernels +} // namespace nnfusion \ No newline at end of file diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/loop.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/loop.cpp new file mode 100644 index 000000000..3e2c6c8d6 --- /dev/null +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/loop.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#include "loop.hpp" +#include "../cuda_cudnn.hpp" +#include "convolution.hpp" + +using namespace nnfusion; +using namespace nnfusion::kernels; + +cuda::Loop::Loop(shared_ptr ctx) + : KernelEmitter(ctx) +{ + std::stringstream tag; + tag << "_LoopOP"; + custom_tag = tag.str(); +} + +LanguageUnit_p cuda::Loop::emit_function_body() +{ + LanguageUnit_p _lu(new LanguageUnit(get_function_name())); + auto& lu = *_lu; + + // function signature: + // extern "C" __global__ void kernel(m_context->dtypes[0]* input0, m_context->dtypes[0]* input1, m_context->dtypes[2]* output0) + lu << "// TODO\n"; + return _lu; +} + +LanguageUnit_p cuda::Loop::emit_dependency() +{ + LanguageUnit_p _lu(new LanguageUnit(get_function_name() + "_dep")); + _lu->require(header::cuda); + return _lu; +} + +REGISTER_KERNEL_EMITTER("Loop", // op_name + Device(CUDA_GPU).TypeConstraint(element::f32).Priority(2), // attrs + cuda::Loop) \ No newline at end of file diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/loop.hpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/loop.hpp new file mode 100644 index 000000000..58ae29241 --- /dev/null +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/loop.hpp @@ -0,0 +1,25 
@@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once +#include "../cuda_emitter.hpp" +#include "../cuda_langunit.hpp" + +namespace nnfusion +{ + namespace kernels + { + namespace cuda + { + class Loop : public KernelEmitter + { + public: + Loop(shared_ptr ctx); + + LanguageUnit_p emit_function_body() override; + LanguageUnit_p emit_dependency() override; + // LanguageUnit_p emit_function_signature() override; + }; + } // namespace cuda + } // namespace kernels +} // namespace nnfusion \ No newline at end of file diff --git a/src/nnfusion/core/operators/CMakeLists.txt b/src/nnfusion/core/operators/CMakeLists.txt index f445946fc..347692ec6 100644 --- a/src/nnfusion/core/operators/CMakeLists.txt +++ b/src/nnfusion/core/operators/CMakeLists.txt @@ -42,15 +42,18 @@ set(SRC op_define/gelu.cpp op_define/greater_eq.cpp op_define/greater.cpp + op_define/if.cpp op_define/less_eq.cpp op_define/less.cpp op_define/log.cpp + op_define/loop.cpp op_define/lrn.cpp op_define/max_pool.cpp op_define/max.cpp op_define/maximum.cpp op_define/min.cpp op_define/minimum.cpp + op_define/mod.cpp op_define/multiply.cpp op_define/negative.cpp op_define/not_equal.cpp diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/Elementwise.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/Elementwise.cpp index bc53d9e27..f118fcfdd 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op_define/Elementwise.cpp +++ b/src/nnfusion/core/operators/generic_op/generic_op_define/Elementwise.cpp @@ -44,6 +44,7 @@ static const std::unordered_map ElementOpMap = { element_op( "divnonan", "(x0 / x1).when([x1 != const(0).cast(x1.dtype())], const(0).cast(input1[].dtype()))")}, + {"Mod", element_op("fmod", "")}, {"Square", element_op("square", "x0 * x0")}, {"Negative", element_op("negative", "-x0")}, {"Select", element_op("select", "x2.when([x0 == 0], x1)")}, @@ -174,6 +175,7 @@ REGISTER_ELEM_OP(Subtract) REGISTER_ELEM_OP(Multiply) 
REGISTER_ELEM_OP(Divide) REGISTER_ELEM_OP(DivNoNan) +REGISTER_ELEM_OP(Mod) REGISTER_ELEM_OP(Square) REGISTER_ELEM_OP(Negative) REGISTER_ELEM_OP(Select) diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/ScatterND.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/ScatterND.cpp index 6174bc824..1bc2b03b8 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op_define/ScatterND.cpp +++ b/src/nnfusion/core/operators/generic_op/generic_op_define/ScatterND.cpp @@ -52,4 +52,4 @@ REGISTER_OP(ScatterND) {"update_layout", vector_to_string>(update_layout)}, {"output_layout", vector_to_string>(output_layout)}, }); - }); \ No newline at end of file + }); diff --git a/src/nnfusion/core/operators/op_define/if.cpp b/src/nnfusion/core/operators/op_define/if.cpp new file mode 100644 index 000000000..41f45d529 --- /dev/null +++ b/src/nnfusion/core/operators/op_define/if.cpp @@ -0,0 +1,58 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#include "if.hpp" + +using namespace std; +using namespace nnfusion::op; + +If::If(std::shared_ptr& then_branch_graph, + std::shared_ptr& else_branch_graph) + : Op("If") + , m_then_branch_graph(then_branch_graph) + , m_else_branch_graph(else_branch_graph) +{ +} + +void If::validate_and_infer_types(std::shared_ptr gnode) +{ + nnfusion::Shape cond_shape = gnode->get_input_shape(0); + nnfusion::element::Type cond_et = gnode->get_input_element_type(0); + NNFUSION_CHECK(cond_shape.size() == 0) + << "The condition tensor of the If operation must be scalar."; + NNFUSION_CHECK(cond_et == nnfusion::element::boolean) + << "The condition tensor of the If operation must be boolean."; + + auto then_branch_outputs = m_then_branch_graph->get_outputs(); + auto else_branch_outputs = m_else_branch_graph->get_outputs(); + NNFUSION_CHECK(then_branch_outputs.size() == else_branch_outputs.size()) + << "The outputs in the then_branch and else_branch must have the same shape and " + "same data type."; + for (size_t i = 0; i < then_branch_outputs.size(); i++) + { + NNFUSION_CHECK(then_branch_outputs[i]->get_shape() == else_branch_outputs[i]->get_shape() && + then_branch_outputs[i]->get_element_type() == + else_branch_outputs[i]->get_element_type()) + << "The outputs in the then_branch and else_branch must have the same shape and " + "same data type."; + + gnode->set_output_type_and_shape( + i, then_branch_outputs[i]->get_element_type(), then_branch_outputs[i]->get_shape()); + } +} \ No newline at end of file diff --git a/src/nnfusion/core/operators/op_define/if.hpp b/src/nnfusion/core/operators/op_define/if.hpp new file mode 100644 index 000000000..acd093aaf --- /dev/null +++ b/src/nnfusion/core/operators/op_define/if.hpp @@ -0,0 +1,49 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with
the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +#include "../op.hpp" +#include "nnfusion/core/graph/graph.hpp" + +namespace nnfusion +{ + namespace op + { + /// \brief If control-flow operation, with same definition as https://github.com/onnx/onnx/blob/master/docs/Changelog.md#If-1. + class If : public Op + { + public: + /// \brief Constructs an if operation + /// + /// \param then_branch_graph The then_branch graph.
+ /// `[f]` + /// \param else_branch_graph The else_branch graph.
+ /// `[f]` + If(std::shared_ptr& then_branch_graph, + std::shared_ptr& else_branch_graph); + + void validate_and_infer_types(std::shared_ptr gnode) override; + + protected: + std::shared_ptr m_then_branch_graph; + std::shared_ptr m_else_branch_graph; + }; + } +} diff --git a/src/nnfusion/core/operators/op_define/loop.cpp b/src/nnfusion/core/operators/op_define/loop.cpp new file mode 100644 index 000000000..45a1c8a5a --- /dev/null +++ b/src/nnfusion/core/operators/op_define/loop.cpp @@ -0,0 +1,55 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#include "loop.hpp" + +using namespace std; +using namespace nnfusion::op; + +Loop::Loop(std::shared_ptr& loop_body_graph, + const std::vector& output_shapes, + const std::vector& output_types) + : Op("Loop") + , m_loop_body_graph(loop_body_graph) + , m_output_shapes(output_shapes) + , m_output_types(output_types) +{ +} + +void Loop::validate_and_infer_types(std::shared_ptr gnode) +{ + nnfusion::Shape trip_shape = gnode->get_input_shape(0); + nnfusion::element::Type trip_et = gnode->get_input_element_type(0); + NNFUSION_CHECK(trip_shape.size() == 0) + << "The trip-count tensor of the Loop operation must be scalar."; + NNFUSION_CHECK(trip_et == nnfusion::element::i64) + << "The trip-count tensor of the Loop operation must be int64."; + + nnfusion::Shape cond_shape = gnode->get_input_shape(1); + nnfusion::element::Type cond_et = gnode->get_input_element_type(1); + NNFUSION_CHECK(cond_shape.size() == 0) + << "The condition tensor of the Loop operation must be scalar."; + NNFUSION_CHECK(cond_et == nnfusion::element::boolean) + << "The condition tensor of the Loop operation must be boolean."; + + for (size_t i = 0; i < gnode->get_output_size(); i++) + { + gnode->set_output_type_and_shape(i, m_output_types[i], m_output_shapes[i]); + } +} \ No newline at end of file diff --git a/src/nnfusion/core/operators/op_define/loop.hpp b/src/nnfusion/core/operators/op_define/loop.hpp new file mode 100644 index 000000000..d531e6a74 --- /dev/null +++ b/src/nnfusion/core/operators/op_define/loop.hpp @@ -0,0 +1,49 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +#include "../op.hpp" +#include "nnfusion/core/graph/graph.hpp" + +namespace nnfusion +{ + namespace op + { + /// \brief Loop control-flow operation, with same definition as https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Loop-11. + class Loop : public Op + { + public: + /// \brief Constructs an if operation + /// + /// \param loop_body_graph The loop body graph.
+ /// `[f]` + Loop(std::shared_ptr& loop_body_graph, + const std::vector& output_shapes, + const std::vector& output_types); + + void validate_and_infer_types(std::shared_ptr gnode) override; + + protected: + std::shared_ptr m_loop_body_graph; + std::vector m_output_shapes; + std::vector m_output_types; + }; + } // namespace op +} // namespace nnfusion diff --git a/src/nnfusion/core/operators/op_define/mod.cpp b/src/nnfusion/core/operators/op_define/mod.cpp new file mode 100644 index 000000000..9d7b96d69 --- /dev/null +++ b/src/nnfusion/core/operators/op_define/mod.cpp @@ -0,0 +1,26 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//***************************************************************************** + +// Microsoft (c) 2019, NNFusion Team + +#include "mod.hpp" + +using namespace nnfusion::op; + +Mod::Mod() + : ElementwiseArithmetic("Mod") +{ +} \ No newline at end of file diff --git a/src/nnfusion/core/operators/op_define/mod.hpp b/src/nnfusion/core/operators/op_define/mod.hpp new file mode 100644 index 000000000..d2d7723f3 --- /dev/null +++ b/src/nnfusion/core/operators/op_define/mod.hpp @@ -0,0 +1,38 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +// Microsoft (c) 2019, NNFusion Team + +#pragma once + +#include "nnfusion/core/operators/util/elementwise_arithmetic.hpp" + +namespace nnfusion +{ + namespace op + { + /// \brief Elementwise mod operation. + class Mod : public ElementwiseArithmetic + { + public: + /// \brief Constructs a mod operation. 
+ Mod(); + + protected: + virtual bool is_commutative() override { return false; } + }; + }; // namespace op +} // namespace nnfusion diff --git a/src/nnfusion/engine/pass/graph/kernel_tuning.cpp b/src/nnfusion/engine/pass/graph/kernel_tuning.cpp index b9a699c5a..0ca99ef76 100644 --- a/src/nnfusion/engine/pass/graph/kernel_tuning.cpp +++ b/src/nnfusion/engine/pass/graph/kernel_tuning.cpp @@ -20,6 +20,7 @@ DEFINE_int64(fkernel_tuning_steps, 0, "Enable automatic kernel tuning for maximu DEFINE_string(ftuning_blocklist, "", "List of op types that skip kernel tuning pass, e.g., \"Softmax,Add\""); +DEFINE_string(ftuning_list, "", "List of op types for kernel tuning pass, e.g., \"Softmax,Add\""); DEFINE_string(fantares_perf_file, "./antares_perf.csv", "File to save Antares kernel performance."); DECLARE_bool(fantares_mode); DECLARE_string(fantares_codegen_server); @@ -130,6 +131,7 @@ void dump_perf(std::string filename, std::pair>, std::vector>> get_tuning_candidates(std::shared_ptr& graph, + const std::unordered_set tuning_list, const std::unordered_set block_list, std::unordered_map& ir2cnt) { @@ -147,6 +149,12 @@ std::pair>, std::vector(); NNFUSION_CHECK(n_device_type != UNKNOWN); + // filter ops not in TuningList + if (tuning_list.size() > 0 && tuning_list.find(gnode->get_op_type()) == tuning_list.end()) + { + continue; + } + // filter ops in BlockList if (block_list.find(gnode->get_op_type()) != block_list.end()) { @@ -255,11 +263,31 @@ bool KernelTuning::parse_block_list() NNFUSION_LOG(INFO) << "Kernel Tuning BlockList: " << join(BlockList, ", "); } +bool KernelTuning::parse_tuning_list() +{ + auto tuninglist_str = FLAGS_ftuning_list; + stringstream ss(tuninglist_str); + while (ss.good()) + { + string substr; + getline(ss, substr, ','); + TuningList.insert(substr); + } + NNFUSION_LOG(INFO) << "Kernel Tuning List: " << join(TuningList, ", "); +} + bool KernelTuning::run_on_graph(std::shared_ptr& graph) { if (FLAGS_fantares_mode) { + parse_tuning_list(); 
parse_block_list(); + for (auto item : TuningList) + { + NNFUSION_CHECK(BlockList.find(item) == BlockList.end()) + << "Kernel Tuning Pass: There are same operators in TuningList and " + "TuningBlockList."; + } // register antares kernels anyway here in case kernel selection pass will use them register_antares_kernel(); } @@ -274,7 +302,7 @@ bool KernelTuning::run_on_graph(std::shared_ptr& graph) std::vector> tuning_kernels; std::unordered_map ir2cnt; std::vector> nodes; - std::tie(nodes, tuned_kernels) = get_tuning_candidates(graph, BlockList, ir2cnt); + std::tie(nodes, tuned_kernels) = get_tuning_candidates(graph, TuningList, BlockList, ir2cnt); for (auto gnode : nodes) { if (!(*gnode)["DeviceType"].is_valid()) diff --git a/src/nnfusion/engine/pass/graph/kernel_tuning.hpp b/src/nnfusion/engine/pass/graph/kernel_tuning.hpp index 1d49761e2..67cf3a592 100644 --- a/src/nnfusion/engine/pass/graph/kernel_tuning.hpp +++ b/src/nnfusion/engine/pass/graph/kernel_tuning.hpp @@ -22,11 +22,13 @@ namespace nnfusion private: bool parse_block_list(); + bool parse_tuning_list(); bool insert_to_kernel_cache( const std::vector>& nodes); private: std::unordered_set BlockList; + std::unordered_set TuningList; }; } } diff --git a/src/nnfusion/frontend/onnx_import/CMakeLists.txt b/src/nnfusion/frontend/onnx_import/CMakeLists.txt index 55310383d..01305906c 100644 --- a/src/nnfusion/frontend/onnx_import/CMakeLists.txt +++ b/src/nnfusion/frontend/onnx_import/CMakeLists.txt @@ -32,6 +32,10 @@ add_library(onnx_import STATIC core/attribute.cpp core/node.hpp core/node.cpp + op/if.hpp + op/if.cpp + op/loop.hpp + op/loop.cpp op/no.hpp op/no.cpp op/slice.hpp diff --git a/src/nnfusion/frontend/onnx_import/core/attribute.hpp b/src/nnfusion/frontend/onnx_import/core/attribute.hpp index 3596bd61c..12d6b0ae0 100644 --- a/src/nnfusion/frontend/onnx_import/core/attribute.hpp +++ b/src/nnfusion/frontend/onnx_import/core/attribute.hpp @@ -208,6 +208,33 @@ namespace nnfusion } } + template <> + inline 
onnx::GraphProto get_value(const onnx::AttributeProto& attribute) + { + NNFUSION_CHECK(attribute.type() == onnx::AttributeProto_AttributeType_GRAPH) + << "invalid attribute type : " + << onnx::AttributeProto_AttributeType_Name(attribute.type()); + + return attribute.g(); + } + + template <> + inline std::vector + get_value(const onnx::AttributeProto& attribute) + { + switch (attribute.type()) + { + case onnx::AttributeProto_AttributeType_GRAPH: + return {onnx::GraphProto{attribute.g()}}; + case onnx::AttributeProto_AttributeType_GRAPHS: + return {std::begin(attribute.graphs()), std::end(attribute.graphs())}; + default: + NNFUSION_CHECK_FAIL() + << "invalid attribute type : " + << onnx::AttributeProto_AttributeType_Name(attribute.type()); + } + } + } // namespace detail class Attribute diff --git a/src/nnfusion/frontend/onnx_import/onnx.cpp b/src/nnfusion/frontend/onnx_import/onnx.cpp index cbe4fb6e5..a154cffd3 100644 --- a/src/nnfusion/frontend/onnx_import/onnx.cpp +++ b/src/nnfusion/frontend/onnx_import/onnx.cpp @@ -51,42 +51,42 @@ namespace nnfusion load_onnx_model(const std::string& path, const std::unordered_map& dim_params) { - NNFUSION_LOG(INFO) << "Optimizing ONNX Graph with External Tool " - "(models/pytorch2onnx/ort_run_frozen.py)"; - string optimized_filename = string(tmpnam(nullptr)); + // NNFUSION_LOG(INFO) << "Optimizing ONNX Graph with External Tool " + // "(models/pytorch2onnx/ort_run_frozen.py)"; + // string optimized_filename = string(tmpnam(nullptr)); string m_path = path; - string script_path = - nnfusion::codegen::get_file_from_templates("onnx/ort_run_frozen.py"); - string cmd = "python3 " + script_path + - " --graph_optimization_level ORT_ENABLE_BASIC " - "--warmup 1 --iters 0 --provider CPUExecutionProvider --file " + - path + " --optimized_model_filepath " + optimized_filename; - if (dim_params.size() > 0) - { - string dim_params_str = " --symbolic_dims \'{"; - for (auto& it : dim_params) - { - if (dim_params_str != " --symbolic_dims \'{") - 
{ - dim_params_str += ", "; - } - dim_params_str += "\"" + it.first + "\": " + to_string(it.second); - } - dim_params_str += "}\'"; - cmd += dim_params_str; - } - int sys_ret = system(cmd.c_str()); - std::ifstream opt_fin(optimized_filename.c_str()); - if (sys_ret == 0 && opt_fin.is_open()) - { - m_path = optimized_filename; - } - else - { - NNFUSION_LOG(NNFUSION_WARNING) - << "Failed to optimize ONNX Graph with external tool, please " - "check error messages reported by the tool, fallback"; - } + // string script_path = + // nnfusion::codegen::get_file_from_templates("onnx/ort_run_frozen.py"); + // string cmd = "python3 " + script_path + + // " --graph_optimization_level ORT_ENABLE_BASIC " + // "--warmup 1 --iters 0 --provider CPUExecutionProvider --file " + + // path + " --optimized_model_filepath " + optimized_filename; + // if (dim_params.size() > 0) + // { + // string dim_params_str = " --symbolic_dims \'{"; + // for (auto& it : dim_params) + // { + // if (dim_params_str != " --symbolic_dims \'{") + // { + // dim_params_str += ", "; + // } + // dim_params_str += "\"" + it.first + "\": " + to_string(it.second); + // } + // dim_params_str += "}\'"; + // cmd += dim_params_str; + // } + // int sys_ret = system(cmd.c_str()); + // std::ifstream opt_fin(optimized_filename.c_str()); + // if (sys_ret == 0 && opt_fin.is_open()) + // { + // m_path = optimized_filename; + // } + // else + // { + // NNFUSION_LOG(NNFUSION_WARNING) + // << "Failed to optimize ONNX Graph with external tool, please " + // "check error messages reported by the tool, fallback"; + // } std::ifstream ifs{m_path, std::ios::in | std::ios::binary}; NNFUSION_CHECK(ifs.is_open()) << "failure opening file:" + path; @@ -99,10 +99,10 @@ namespace nnfusion auto graph = load_onnx_model(ifs, model_dir, dim_params); - if (opt_fin.is_open()) - { - remove(optimized_filename.c_str()); - } + // if (opt_fin.is_open()) + // { + // remove(optimized_filename.c_str()); + // } return graph; } diff --git 
a/src/nnfusion/frontend/onnx_import/op/constant.hpp b/src/nnfusion/frontend/onnx_import/op/constant.hpp index f3dd0bfc9..2e3d4f86a 100644 --- a/src/nnfusion/frontend/onnx_import/op/constant.hpp +++ b/src/nnfusion/frontend/onnx_import/op/constant.hpp @@ -48,7 +48,8 @@ namespace nnfusion static const std::map( const element::Type&, const Tensor&)>> - the_map = {{element::f32, __make_constant_op}, + the_map = {{element::boolean, __make_constant_op}, + {element::f32, __make_constant_op}, {element::f64, __make_constant_op}, {element::i32, __make_constant_op}, {element::i64, __make_constant_op}, diff --git a/src/nnfusion/frontend/onnx_import/op/if.cpp b/src/nnfusion/frontend/onnx_import/op/if.cpp new file mode 100644 index 000000000..912dc4ac4 --- /dev/null +++ b/src/nnfusion/frontend/onnx_import/op/if.cpp @@ -0,0 +1,379 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#include "if.hpp" +#include "../util/graph_convert.hpp" +#include "../util/util.hpp" +#include "nnfusion/core/operators/op_define/if.hpp" + +using namespace nnfusion::frontend::onnx_import; + +/* +class Model(torch.jit.ScriptModule): + def __init__(self): + super(Model, self).__init__() + + @torch.jit.script_method + def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, cond: int): + t = x - y + if cond > 0: + if cond > 1: + return x * z + else: + return x * t + else: + return x + y + +x = torch.ones([2, 2], dtype=torch.float32) +y = torch.ones([2, 2], dtype=torch.float32) +z = torch.ones([2, 2], dtype=torch.float32) + + +ir_version: 6 +producer_name: "pytorch" +producer_version: "1.6" +graph { + node { + input: "x.1" + input: "y.1" + output: "4" + name: "Sub_0" + op_type: "Sub" + } + node { + output: "5" + name: "Constant_1" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 7 + raw_data: "\000\000\000\000\000\000\000\000" + } + type: TENSOR + } + } + node { + input: "cond.1" + input: "5" + output: "6" + name: "Greater_2" + op_type: 
"Greater" + } + node { + input: "6" + output: "7" + name: "If_3" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + output: "8" + name: "Constant_4" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 7 + raw_data: "\001\000\000\000\000\000\000\000" + } + type: TENSOR + } + } + node { + input: "cond.1" + input: "8" + output: "9" + name: "Greater_5" + op_type: "Greater" + } + node { + input: "9" + output: "10" + name: "If_6" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "x.1" + input: "z.1" + output: "11" + name: "Mul_7" + op_type: "Mul" + } + name: "torch-jit-export2" + output { + name: "11" + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "x.1" + input: "4" + output: "12" + name: "Mul_8" + op_type: "Mul" + } + name: "torch-jit-export3" + output { + name: "12" + } + } + type: GRAPH + } + } + name: "torch-jit-export1" + output { + name: "10" + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "x.1" + input: "y.1" + output: "13" + name: "Add_9" + op_type: "Add" + } + name: "torch-jit-export4" + output { + name: "13" + } + } + type: GRAPH + } + } + name: "torch-jit-export" + input { + name: "x.1" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + input { + name: "y.1" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + input { + name: "z.1" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + input { + name: "cond.1" + type { + tensor_type { + elem_type: 7 + shape { + } + } + } + } + output { + name: "7" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } +} +opset_import { + version: 11 +} +*/ + +namespace nnfusion +{ + namespace frontend + { + namespace 
onnx_import + { + namespace set_1 + { + NamedNodeVector TranslateIfOp(const onnx::NodeProto& node_proto, + const NodeMap& all_ng_nodes, + std::shared_ptr m_graph) + { + NNFUSION_CHECK_FAIL() + << "This is a placeholder convert_func, please use the real one."; + return {}; + } + + NamedNodeVector TranslateIfOp( + const onnx::NodeProto& node_proto, + const NodeMap& all_ng_nodes, + std::shared_ptr m_graph, + const std::unordered_map& domain_convert_func_map, + const string& model_dir, + const std::unordered_map& domain2version, + const std::unordered_map& dim_params) + { + Node node(node_proto); + onnx::GraphProto then_branch_graph_proto = + node.get_attribute_value("then_branch"); + onnx::GraphProto else_branch_graph_proto = + node.get_attribute_value("else_branch"); + + onnx::NodeProto completed_node_proto(node_proto); + auto then_branch_graph_inputs = extract_input(then_branch_graph_proto); + auto else_branch_graph_inputs = extract_input(else_branch_graph_proto); + std::unordered_set node_inputs; + for (size_t i = 0; i < node_proto.input_size(); i++) + { + node_inputs.insert(node_proto.input(i)); + } + for (auto item : then_branch_graph_inputs) + { + if (node_inputs.find(item) == node_inputs.end()) + { + completed_node_proto.add_input(item); + node_inputs.insert(item); + } + } + for (auto item : else_branch_graph_inputs) + { + if (node_inputs.find(item) == node_inputs.end()) + { + completed_node_proto.add_input(item); + node_inputs.insert(item); + } + } + auto input_indexes = GetAllInputIndex(all_ng_nodes, completed_node_proto); + + // process then_branch graph and else_branch_graph + std::shared_ptr then_branch_graph; + std::shared_ptr else_branch_graph; + { + then_branch_graph_proto = complete_graphproto(then_branch_graph_proto); + GraphProtoConvert then_branch_graph_convert(then_branch_graph_proto, + domain_convert_func_map, + model_dir, + domain2version, + dim_params, + all_ng_nodes, + true); + then_branch_graph = then_branch_graph_convert.get_graph(); + + 
else_branch_graph_proto = complete_graphproto(else_branch_graph_proto); + GraphProtoConvert else_branch_graph_convert(else_branch_graph_proto, + domain_convert_func_map, + model_dir, + domain2version, + dim_params, + all_ng_nodes, + true); + else_branch_graph = else_branch_graph_convert.get_graph(); + } + + auto if_op = std::make_shared(then_branch_graph, else_branch_graph); + if_op->set_name(node_proto.name()); + auto if_gnode = m_graph->add_node_and_edge(if_op, input_indexes); + + NamedNodeVector ret; + for (size_t i = 0; i < node_proto.output_size(); i++) + { + ret.push_back(NamedNode(node_proto.output(i), if_gnode, i)); + } + + return ret; + + // for (auto item : all_ng_nodes) + // { + // std::cout << "NodeMap[" << item.first << "]: " << item.second.size() << std::endl; + // } + + // // std::cout << then_branch_graph_proto.DebugString() << std::endl; + // // std::cout << else_branch_graph_proto.DebugString() << std::endl; + + // std::vector model_inputs; + // for (auto i = 0; i < model_proto.graph().input_size(); i++) + // { + // model_inputs.push_back(model_proto.graph().input(i)); + // } + + // for (auto i = 0; i < model_inputs.size(); i++) + // { + // auto input = then_branch_graph_proto.add_input(); + // input->CopyFrom(model_inputs[i]); + // } + // for (auto i = 0; i < model_inputs.size(); i++) + // { + // auto input = else_branch_graph_proto.add_input(); + // input->CopyFrom(model_inputs[i]); + // } + + // onnx::ModelProto then_branch_model_proto = onnx::ModelProto(model_proto); + // then_branch_model_proto.set_allocated_graph(&then_branch_graph_proto); + // // GraphConvert then_branch_converter = + // // GraphConvert(then_branch_model_proto, {}, "", all_ng_nodes); + // // auto then_branch_graph = then_branch_converter.get_graph(); + // // auto then_branch_graph = std::make_shared() + + // onnx::ModelProto else_branch_model_proto = onnx::ModelProto(model_proto); + // else_branch_model_proto.set_allocated_graph(&else_branch_graph_proto); + // // 
GraphConvert else_branch_converter = + // // GraphConvert(else_branch_model_proto, {}, "", all_ng_nodes); + // // auto else_branch_graph = else_branch_converter.get_graph(); + + // // exit(1); + } + + } // namespace set_1 + + } //namespace onnx_import + + } // namespace frontend + +} // namespace nnfusion diff --git a/src/nnfusion/frontend/onnx_import/op/if.hpp b/src/nnfusion/frontend/onnx_import/op/if.hpp new file mode 100644 index 000000000..3844d5270 --- /dev/null +++ b/src/nnfusion/frontend/onnx_import/op/if.hpp @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +#include "core/node.hpp" + +namespace nnfusion +{ + namespace frontend + { + namespace onnx_import + { + namespace set_1 + { + NamedNodeVector TranslateIfOp(const onnx::NodeProto& node_proto, + const NodeMap& all_ng_nodes, + std::shared_ptr m_graph); + + NamedNodeVector TranslateIfOp( + const onnx::NodeProto& node_proto, + const NodeMap& all_ng_nodes, + std::shared_ptr m_graph, + const std::unordered_map& domain_convert_func_map, + const string& model_dir, + const std::unordered_map& domain2version, + const std::unordered_map& dim_params = {}); + + } // namespace set_1 + + } //namespace onnx_import + + } // namespace frontend + +} // namespace nnfusion diff --git a/src/nnfusion/frontend/onnx_import/op/loop.cpp b/src/nnfusion/frontend/onnx_import/op/loop.cpp new file mode 100644 index 000000000..67bceed59 --- /dev/null +++ b/src/nnfusion/frontend/onnx_import/op/loop.cpp @@ -0,0 +1,280 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#include "loop.hpp" +#include "../util/graph_convert.hpp" +#include "../util/util.hpp" +#include "nnfusion/core/operators/op_define/loop.hpp" + +using namespace nnfusion::frontend::onnx_import; + +/* +class Model(torch.jit.ScriptModule): + def __init__(self): + super(Model, self).__init__() + + @torch.jit.script_method + def forward(self, x: torch.Tensor, num_loop: int): + ret = x + for i in range(num_loop): + ret = ret + x + return ret + +x = torch.ones([2, 2], dtype=torch.float32) +a = torch.tensor(5) + + +ir_version: 6 +producer_name: "pytorch" +producer_version: "1.9" +graph { + node { + output: "2" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 9 + raw_data: "\001" + } + type: TENSOR + } + } + node { + input: "num_loop.1" + input: "2" + input: "x.1" + output: "3" + name: "Loop_1" + op_type: "Loop" + attribute { + name: "body" + g { + node { + input: "ret.9" + input: "x.1" + output: "7" + name: "Add_2" + op_type: "Add" + } + node { + input: "2" + output: "8" + name: "Identity_3" + op_type: "Identity" + } + name: "torch-jit-export1" + input { + name: "i" + type { + tensor_type { + elem_type: 7 + shape { + } + } + } + } + input { + name: "cond" + type { + tensor_type { + elem_type: 9 + shape { + } + } + } + } + input { + name: "ret.9" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "8" + type { + tensor_type { + elem_type: 9 + shape { + } + } + } + } + output { + name: "7" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + } + type: GRAPH + } + } + name: "torch-jit-export" + input { + name: "x.1" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + input { + name: "num_loop.1" + type { + tensor_type { + elem_type: 7 + shape { + } + } + } + } + output { + name: "3" + type { + tensor_type { 
+ elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } +} +opset_import { + version: 11 +} +*/ + +namespace nnfusion +{ + namespace frontend + { + namespace onnx_import + { + namespace set_1 + { + NamedNodeVector TranslateLoopOp(const onnx::NodeProto& node_proto, + const NodeMap& all_ng_nodes, + std::shared_ptr m_graph) + { + NNFUSION_CHECK_FAIL() + << "This is a placeholder convert_func, please use the real one."; + return {}; + } + + NamedNodeVector TranslateLoopOp( + const onnx::NodeProto& node_proto, + const NodeMap& all_ng_nodes, + std::shared_ptr m_graph, + const std::unordered_map& domain_convert_func_map, + const string& model_dir, + const std::unordered_map& domain2version, + const std::unordered_map& dim_params) + { + Node node(node_proto); + onnx::GraphProto loop_body_graph_proto = + node.get_attribute_value("body"); + + onnx::NodeProto completed_node_proto(node_proto); + auto loop_body_graph_inputs = extract_input(loop_body_graph_proto); + std::unordered_set node_inputs; + for (size_t i = 0; i < node_proto.input_size(); i++) + { + node_inputs.insert(node_proto.input(i)); + } + for (const auto& input_proto : loop_body_graph_proto.input()) + { + node_inputs.insert(input_proto.name()); + } + for (auto item : loop_body_graph_inputs) + { + if (node_inputs.find(item) == node_inputs.end()) + { + completed_node_proto.add_input(item); + node_inputs.insert(item); + } + } + + auto input_indexes = GetAllInputIndex(all_ng_nodes, completed_node_proto); + + // process loop_body_graph + std::shared_ptr loop_body_graph; + { + loop_body_graph_proto = complete_graphproto(loop_body_graph_proto); + std::cout << loop_body_graph_proto.DebugString() << std::endl; + GraphProtoConvert loop_body_graph_convert(loop_body_graph_proto, + domain_convert_func_map, + model_dir, + domain2version, + dim_params, + all_ng_nodes, + true); + loop_body_graph = loop_body_graph_convert.get_graph(); + } + + std::vector output_shapes; + std::vector 
output_types; + for (size_t i = 1; i < loop_body_graph_proto.output().size(); i++) + { + ValueInfo output_value_info(loop_body_graph_proto.output()[i], dim_params); + output_shapes.push_back(output_value_info.get_shape()); + output_types.push_back(output_value_info.get_element_type()); + } + + auto loop_op = + std::make_shared(loop_body_graph, output_shapes, output_types); + loop_op->set_name(node_proto.name()); + auto loop_gnode = m_graph->add_node_and_edge( + loop_op, input_indexes, /* output_size */ node_proto.output_size()); + + NamedNodeVector ret; + for (size_t i = 0; i < node_proto.output_size(); i++) + { + ret.push_back(NamedNode(node_proto.output(i), loop_gnode, i)); + } + + return ret; + } + + } // namespace set_1 + + } //namespace onnx_import + + } // namespace frontend + +} // namespace nnfusion diff --git a/src/nnfusion/frontend/onnx_import/op/loop.hpp b/src/nnfusion/frontend/onnx_import/op/loop.hpp new file mode 100644 index 000000000..efc87dbf7 --- /dev/null +++ b/src/nnfusion/frontend/onnx_import/op/loop.hpp @@ -0,0 +1,187 @@ +/* +ir_version: 6 +producer_name: "pytorch" +producer_version: "1.6" +graph { + node { + output: "1" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 7 + raw_data: "\001\000\000\000\000\000\000\000" + } + type: TENSOR + } + } + node { + output: "2" + name: "Constant_1" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 2 + dims: 2 + data_type: 1 + raw_data: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?" 
+ } + type: TENSOR + } + } + node { + input: "1" + output: "3" + name: "Cast_2" + op_type: "Cast" + attribute { + name: "to" + i: 9 + type: INT + } + } + node { + input: "num_loop.1" + input: "3" + input: "2" + output: "4" + name: "Loop_3" + op_type: "Loop" + attribute { + name: "body" + g { + node { + output: "8" + name: "Constant_4" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 2 + dims: 2 + data_type: 1 + raw_data: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?" + } + type: TENSOR + } + } + node { + input: "ret.6" + input: "8" + output: "9" + name: "Add_5" + op_type: "Add" + } + node { + input: "1" + output: "10" + name: "Cast_6" + op_type: "Cast" + attribute { + name: "to" + i: 9 + type: INT + } + } + name: "torch-jit-export1" + input { + name: "i" + type { + tensor_type { + elem_type: 7 + shape { + } + } + } + } + input { + name: "cond" + type { + tensor_type { + elem_type: 9 + } + } + } + input { + name: "ret.6" + } + output { + name: "10" + } + output { + name: "9" + } + } + type: GRAPH + } + } + name: "torch-jit-export" + input { + name: "num_loop.1" + type { + tensor_type { + elem_type: 7 + shape { + } + } + } + } + output { + name: "4" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } +} +opset_import { + version: 11 +} +*/ + +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/node.hpp" + +namespace nnfusion +{ + namespace frontend + { + namespace onnx_import + { + namespace set_1 + { + NamedNodeVector TranslateLoopOp(const onnx::NodeProto& node_proto, + const NodeMap& all_ng_nodes, + std::shared_ptr m_graph); + + NamedNodeVector TranslateLoopOp( + const onnx::NodeProto& node_proto, + const NodeMap& all_ng_nodes, + std::shared_ptr m_graph, + const std::unordered_map& domain_convert_func_map, + const string& model_dir, + const std::unordered_map& domain2version, + const std::unordered_map& dim_params = {}); + + } // namespace set_1 + + } //namespace onnx_import + + } // namespace frontend + +} // namespace nnfusion diff --git a/src/nnfusion/frontend/onnx_import/ops_bridge.cpp b/src/nnfusion/frontend/onnx_import/ops_bridge.cpp index 4bd101045..50b4f2282 100644 --- a/src/nnfusion/frontend/onnx_import/ops_bridge.cpp +++ b/src/nnfusion/frontend/onnx_import/ops_bridge.cpp @@ -46,10 +46,12 @@ #include "op/gather.hpp" #include "op/gemm.hpp" #include "op/identity.hpp" +#include "op/if.hpp" #include "op/index_reduce.hpp" #include "op/layer_norm.hpp" #include "op/leaky_relu.hpp" #include "op/log_softmax.hpp" +#include "op/loop.hpp" #include "op/lstm.hpp" #include "op/matmul.hpp" #include "op/memory_copy.hpp" @@ -191,12 +193,17 @@ namespace nnfusion REGISTER_OPERATOR("Greater", 1, TranslateBinaryOp); //REGISTER_OPERATOR("HardSigmoid", 1, hard_sigmoid); REGISTER_OPERATOR("Identity", 1, TranslateIdentityOp); + // REGISTER_OPERATOR("If", 1, TranslateIfOp); + REGISTER_OPERATOR( + "If", 1, TranslateIdentityOp); // TODO(lingm): fix convert_func map REGISTER_OPERATOR("LayerNormalization", 1, TranslateLayerNormalizationOp); REGISTER_OPERATOR("LayerNormalizationGrad", 1, TranslateLayerNormalizationGradOp); REGISTER_OPERATOR("LeakyRelu", 1, TranslateLeakyReluOp); REGISTER_OPERATOR("Less", 1, TranslateBinaryOp); REGISTER_OPERATOR("Log", 1, TranslateUnaryOp); REGISTER_OPERATOR("LogSoftmax", 1, TranslateLogSoftmaxOp); + 
REGISTER_OPERATOR( + "Loop", 1, TranslateIdentityOp); // TODO(lingm): fix convert_func map //REGISTER_OPERATOR("LRN", 1, lrn); REGISTER_OPERATOR("LSTM", 1, TranslateLstmOp); REGISTER_OPERATOR("MatMul", 1, TranslateMatmulOp); @@ -206,6 +213,7 @@ namespace nnfusion REGISTER_OPERATOR("MemcpyFromHost", 1, TranslateMemcpyFromHostOp); REGISTER_OPERATOR("MemcpyToHost", 1, TranslateMemcpyToHostOp); REGISTER_OPERATOR("Min", 1, TranslateLegacyBinaryOp); + REGISTER_OPERATOR("Mod", 1, TranslateBinaryOp); REGISTER_OPERATOR("Mul", 1, TranslateLegacyBinaryOp); REGISTER_OPERATOR("Mul", 7, TranslateBinaryOp); REGISTER_OPERATOR("Neg", 1, TranslateUnaryOp); @@ -230,6 +238,7 @@ namespace nnfusion REGISTER_OPERATOR("Relu", 1, TranslateUnaryOp); REGISTER_OPERATOR("Reshape", 1, TranslateReshapeOp); REGISTER_OPERATOR("ReshapeGrad", 1, TranslateReshapeGradOp); + REGISTER_OPERATOR("ScatterND", 11, TranslateScatterNDOp); //REGISTER_OPERATOR("Selu", 1, selu); REGISTER_OPERATOR("Shape", 1, TranslateShapeOp); REGISTER_OPERATOR("Sigmoid", 1, TranslateUnaryOp); @@ -270,7 +279,6 @@ namespace nnfusion REGISTER_OPERATOR("Resize", 1, TranslateResizeOp); REGISTER_OPERATOR("Upsample", 1, TranslateResizeOp); REGISTER_OPERATOR("Where", 1, TranslateWhereOp); - REGISTER_OPERATOR("ScatterND", 11, TranslateScatterNDOp); REGISTER_OPERATOR("DepthToSpace", 1, TranslateDepthToSpaceOp); REGISTER_OPERATOR("DepthToSpace", 11, TranslateDepthToSpaceOp); REGISTER_DOMAIN_OPERATOR("org.pytorch.aten", "roll", 1, TranslateRollOp); diff --git a/src/nnfusion/frontend/onnx_import/util/graph_convert.cpp b/src/nnfusion/frontend/onnx_import/util/graph_convert.cpp index f4683623d..031003910 100644 --- a/src/nnfusion/frontend/onnx_import/util/graph_convert.cpp +++ b/src/nnfusion/frontend/onnx_import/util/graph_convert.cpp @@ -22,9 +22,12 @@ #include "graph_convert.hpp" #include #include +#include "../op/if.hpp" +#include "../op/loop.hpp" #include "nnfusion/core/operators/generic_op/generic_op.hpp" #include "op/custom_op.hpp" 
#include "ops_bridge.hpp" +#include "util.hpp" DECLARE_bool(ftraining_mode); @@ -224,6 +227,263 @@ namespace nnfusion tensor.set_raw_data(raw_data); } } + } // namespace + + GraphProtoConvert::GraphProtoConvert( + const onnx::GraphProto& graph_proto, + const std::unordered_map& domain_convert_func_map, + const string& model_dir, + const std::unordered_map& domain2version, + const std::unordered_map& dim_params, + const NodeMap& node_map, + bool flag_subgraph) + : onnx_graph_proto(&graph_proto) + , m_domain_convert_func_map(domain_convert_func_map) + , m_model_dir(model_dir) + , m_domain2version(domain2version) + , m_dim_params(dim_params) + , m_node_map(node_map) + , m_flag_subgraph(flag_subgraph) + { + m_graph = std::make_shared(); + + NNFUSION_CHECK(onnx_graph_proto->sparse_initializer_size() == 0) + << "sparse_initializer not supported"; + + for (const auto& output : onnx_graph_proto->output()) + { + m_output_names.insert(output.name()); + } + + for (auto tensor : onnx_graph_proto->initializer()) + { + if (tensor.has_name()) + { + move_external_to_rawdata(tensor, m_model_dir); + if (FLAGS_ftraining_mode) + { + element::Type type; + ONNXDataTypeToNNFusionElementType( + static_cast(tensor.data_type()), &type); + std::shared_ptr input_gnode; + auto tensor_op = std::make_shared( + type, + Shape(std::begin(tensor.dims()), std::end(tensor.dims())), + false, + true); + tensor_op->set_name(tensor.name()); + input_gnode = + m_graph->add_node_and_edge(tensor_op, graph::GNodeVector({})); + m_node_map[tensor.name()] = {GNodeIndex{input_gnode}}; + } + else + { + auto tensor_op = make_constant_op( + static_cast(tensor.data_type()), + Shape(std::begin(tensor.dims()), std::end(tensor.dims())), + Tensor{tensor}); + tensor_op->set_name(tensor.name()); + auto tensor_gnode = + m_graph->add_node_and_edge(tensor_op, graph::GNodeVector({})); + m_node_map[tensor.name()] = {GNodeIndex{tensor_gnode}}; + } + } + } + // Process all ONNX graph inputs, convert them to NNFusion nodes + if 
(m_flag_subgraph) + { + for (const auto& input_proto : onnx_graph_proto->input()) + { + std::shared_ptr input_gnode; + auto it = m_node_map.find(input_proto.name()); + std::shared_ptr input_op; + if (it != std::end(m_node_map)) + { + input_op = std::make_shared( + it->second[0].get_element_type(), it->second[0].get_shape()); + } + else + { + ValueInfo input_value_info(input_proto, m_dim_params); + input_op = std::make_shared( + input_value_info.get_element_type(), input_value_info.get_shape()); + } + input_op->set_name(input_proto.name()); + input_gnode = m_graph->add_node_and_edge(input_op, graph::GNodeVector({})); + m_node_map[input_proto.name()] = {GNodeIndex{input_gnode}}; + if (m_output_names.find(input_gnode->get_name()) != m_output_names.end()) + { + // TODO: should specify which output of current gnode + m_graph_outputs.emplace_back(input_gnode); + } + } + } + else + { + for (const auto& input_proto : onnx_graph_proto->input()) + { + ValueInfo input_value_info(input_proto, m_dim_params); + std::shared_ptr input_gnode; + // TODO: parameter might have default value in initializer + auto it = m_node_map.find(input_proto.name()); + if (it != std::end(m_node_map)) + { + NNFUSION_LOG(NNFUSION_WARNING) + << "Ignore input: " << input_proto.name() + << ", because it has a default initializers"; + NNFUSION_CHECK(it->second.size() == 1) + << "Multi outputs found for initializer " << input_proto.name(); + if (it->second[0].get_element_type() != + input_value_info.get_element_type()) + { + auto cast_op = std::make_shared( + input_value_info.get_element_type()); + cast_op->set_name(input_proto.name()); + auto input_gnode = m_graph->add_node_and_edge(cast_op, it->second); + m_node_map[input_proto.name()] = {GNodeIndex{input_gnode}}; + if (m_output_names.find(input_gnode->get_name()) != + m_output_names.end()) + { + // TODO: should specify which output of current gnode + m_graph_outputs.emplace_back(input_gnode); + } + } + } + else + { + auto input_op = std::make_shared( + 
input_value_info.get_element_type(), input_value_info.get_shape()); + input_op->set_name(input_proto.name()); + input_gnode = + m_graph->add_node_and_edge(input_op, graph::GNodeVector({})); + m_node_map[input_proto.name()] = {GNodeIndex{input_gnode}}; + if (m_output_names.find(input_gnode->get_name()) != + m_output_names.end()) + { + // TODO: should specify which output of current gnode + m_graph_outputs.emplace_back(input_gnode); + } + } + } + } + + // Process ONNX graph nodes, convert to nGraph nodes + // sorted to avoid non-stardard model + std::vector onnx_nodes(std::begin(onnx_graph_proto->node()), + std::end(onnx_graph_proto->node())); + std::unordered_set external_values{ + ""}; // values provided by initializers/params, empty string means option input + std::transform(std::begin(onnx_graph_proto->initializer()), + std::end(onnx_graph_proto->initializer()), + std::inserter(external_values, external_values.begin()), + [](onnx::TensorProto t) -> std::string { return t.name(); }); + std::transform(std::begin(onnx_graph_proto->input()), + std::end(onnx_graph_proto->input()), + std::inserter(external_values, external_values.begin()), + [](onnx::ValueInfoProto v) -> std::string { return v.name(); }); + if (!is_sorted(onnx_nodes, external_values)) + { + NNFUSION_LOG(NNFUSION_WARNING) << "Resorting ONNX nodes..."; + onnx_nodes = tp_sort(onnx_nodes, external_values); + } + + for (const auto& node_proto : onnx_nodes) + { + auto results = convert_node(node_proto); + for (auto& named_gnode : results) + { + m_node_map[named_gnode.name] = {named_gnode.gnode_index}; + + if (m_output_names.find(named_gnode.name) != m_output_names.end()) + { + // TODO: should specify which output of current gnode + named_gnode.gnode_index.gnode->get_output_tensor_ptr(0)->set_name( + named_gnode.name); + m_graph_outputs.emplace_back(named_gnode.gnode_index.gnode); + } + } + } + + // Sort nGraph output nodes in the order of ONNX output nodes + std::vector output_names; + for (const auto& 
output : onnx_graph_proto->output()) + { + output_names.push_back(output.name()); + } + sort(m_graph_outputs.begin(), + m_graph_outputs.end(), + [&output_names](const std::shared_ptr& a, + const std::shared_ptr& b) { + return std::find(output_names.begin(), output_names.end(), a->get_name()) < + std::find(output_names.begin(), output_names.end(), b->get_name()); + }); + + m_graph->set_default_parameters(); + m_graph->set_outputs(m_graph_outputs); + } + + NamedNodeVector GraphProtoConvert::convert_node(const onnx::NodeProto& node_proto) + { + NNFUSION_LOG(INFO) << "convert node: " << node_proto.name(); + NamedNodeVector ret; + if (node_proto.op_type() == "If") + { + ret = set_1::TranslateIfOp(node_proto, + m_node_map, + m_graph, + m_domain_convert_func_map, + m_model_dir, + m_domain2version, + m_dim_params); + } + else if (node_proto.op_type() == "Loop") + { + ret = set_1::TranslateLoopOp(node_proto, + m_node_map, + m_graph, + m_domain_convert_func_map, + m_model_dir, + m_domain2version, + m_dim_params); + } + else + { + ret = get_convert_func(node_proto.op_type(), + node_proto.domain())(node_proto, m_node_map, m_graph); + const auto& convert_func = + get_convert_func(node_proto.op_type(), node_proto.domain()); + if (convert_func) + { + ret = convert_func(node_proto, m_node_map, m_graph); + } + else + { + NNFUSION_LOG(NNFUSION_WARNING) + << "No translator for " + << (node_proto.domain().empty() ? 
"" : node_proto.domain() + ".") + << node_proto.op_type() << ", try to convert to generic op"; + ret = TranslateCustomOp(node_proto, + m_node_map, + m_graph, + m_domain2version.at(node_proto.domain())); + } + } + for (int i = 0; i < ret.size(); i++) + { + NNFUSION_LOG(INFO) << "node " << node_proto.name() << ", output " << ret[i].name + << ", shape " << ret[i].gnode_index.get_shape(); + } + return std::move(ret); + } + + const ConvertFunc& GraphProtoConvert::get_convert_func(const std::string& name, + const std::string& domain) const + { + if (m_domain_convert_func_map.find(domain) == m_domain_convert_func_map.end() || + m_domain_convert_func_map.at(domain).find(name) == + m_domain_convert_func_map.at(domain).end()) + return EMPTY_CONVERT_FUNC; + return m_domain_convert_func_map.at(domain).at(name); } GraphConvert::GraphConvert(const onnx::ModelProto& model_proto, @@ -233,9 +493,9 @@ namespace nnfusion , onnx_graph_proto(&(model_proto.graph())) , m_graph(new nnfusion::graph::Graph()) , m_dim_params(dim_params) - , model_dir(model_dir) + , m_model_dir(model_dir) { - // print_model_proto(model_proto); + print_model_proto(model_proto); // Note: onnx connect nodes by tensor's name instead of op name /* @@ -324,7 +584,7 @@ namespace nnfusion id.domain(), OperatorsBridge::get_convert_func_map( id.version(), (id.domain() == "ai.onnx" ? "" : id.domain()))); - domain2version[id.domain()] = id.version(); + m_domain2version[id.domain()] = id.version(); } // onnx.proto(.3): the empty string ("") for domain or absence of opset_import field // implies the operator set that is defined as part of the ONNX specification. 
@@ -333,97 +593,10 @@ namespace nnfusion { m_domain_convert_func_map.emplace( "", OperatorsBridge::get_convert_func_map(ONNX_OPSET_VERSION, "")); - domain2version[""] = ONNX_OPSET_VERSION; + m_domain2version[""] = ONNX_OPSET_VERSION; } - m_graph = std::make_shared(); - - NNFUSION_CHECK(onnx_graph_proto->sparse_initializer_size() == 0) - << "sparse_initializer not supported"; - - for (const auto& output : onnx_graph_proto->output()) - { - m_output_names.insert(output.name()); - } - - for (auto tensor : onnx_graph_proto->initializer()) - { - if (tensor.has_name()) - { - move_external_to_rawdata(tensor, model_dir); - if (FLAGS_ftraining_mode) - { - element::Type type; - ONNXDataTypeToNNFusionElementType( - static_cast(tensor.data_type()), &type); - std::shared_ptr input_gnode; - auto tensor_op = std::make_shared( - type, - Shape(std::begin(tensor.dims()), std::end(tensor.dims())), - false, - true); - tensor_op->set_name(tensor.name()); - input_gnode = - m_graph->add_node_and_edge(tensor_op, graph::GNodeVector({})); - m_node_map[tensor.name()] = {GNodeIndex{input_gnode}}; - } - else - { - auto tensor_op = make_constant_op( - static_cast(tensor.data_type()), - Shape(std::begin(tensor.dims()), std::end(tensor.dims())), - Tensor{tensor}); - tensor_op->set_name(tensor.name()); - auto tensor_gnode = - m_graph->add_node_and_edge(tensor_op, graph::GNodeVector({})); - m_node_map[tensor.name()] = {GNodeIndex{tensor_gnode}}; - } - } - } - // Process all ONNX graph inputs, convert them to NNFusion nodes - for (const auto& input_proto : onnx_graph_proto->input()) - { - ValueInfo input_value_info(input_proto, m_dim_params); - std::shared_ptr input_gnode; - // TODO: parameter might have default value in initializer - auto it = m_node_map.find(input_proto.name()); - if (it != std::end(m_node_map)) - { - NNFUSION_LOG(NNFUSION_WARNING) << "Ignore input: " << input_proto.name() - << ", because it has a default initializers"; - NNFUSION_CHECK(it->second.size() == 1) - << "Multi outputs 
found for initializer " << input_proto.name(); - if (it->second[0].get_element_type() != input_value_info.get_element_type()) - { - auto cast_op = - std::make_shared(input_value_info.get_element_type()); - cast_op->set_name(input_proto.name()); - auto input_gnode = m_graph->add_node_and_edge(cast_op, it->second); - m_node_map[input_proto.name()] = {GNodeIndex{input_gnode}}; - if (m_output_names.find(input_gnode->get_name()) != - m_output_names.end()) - { - // TODO: should specify which output of current gnode - m_graph_outputs.emplace_back(input_gnode); - } - } - } - else - { - auto input_op = std::make_shared( - input_value_info.get_element_type(), input_value_info.get_shape()); - input_op->set_name(input_proto.name()); - input_gnode = m_graph->add_node_and_edge(input_op, graph::GNodeVector({})); - m_node_map[input_proto.name()] = {GNodeIndex{input_gnode}}; - if (m_output_names.find(input_gnode->get_name()) != m_output_names.end()) - { - // TODO: should specify which output of current gnode - m_graph_outputs.emplace_back(input_gnode); - } - } - } - - // Verify that ONNX graph contains only nodes of available operator types + // // Verify that ONNX graph contains only nodes of available operator types // { // std::unordered_map domain2version; // for (const auto& id : onnx_model_proto->opset_import()) @@ -457,99 +630,61 @@ namespace nnfusion // } // } - // Process ONNX graph nodes, convert to nGraph nodes - // sorted to avoid non-stardard model - std::vector onnx_nodes(std::begin(onnx_graph_proto->node()), - std::end(onnx_graph_proto->node())); - std::unordered_set external_values{ - ""}; // values provided by initializers/params, empty string means option input - std::transform(std::begin(onnx_graph_proto->initializer()), - std::end(onnx_graph_proto->initializer()), - std::inserter(external_values, external_values.begin()), - [](onnx::TensorProto t) -> std::string { return t.name(); }); - std::transform(std::begin(onnx_graph_proto->input()), - 
std::end(onnx_graph_proto->input()), - std::inserter(external_values, external_values.begin()), - [](onnx::ValueInfoProto v) -> std::string { return v.name(); }); - if (!is_sorted(onnx_nodes, external_values)) - { - NNFUSION_LOG(NNFUSION_WARNING) << "Resorting ONNX nodes..."; - onnx_nodes = tp_sort(onnx_nodes, external_values); - } + // m_controlflow_graphproto_map = construct_controlflow_graphproto(*onnx_graph_proto); - for (const auto& node_proto : onnx_nodes) - { - auto results = convert_node(node_proto); - for (auto& named_gnode : results) - { - m_node_map[named_gnode.name] = {named_gnode.gnode_index}; - - if (m_output_names.find(named_gnode.name) != m_output_names.end()) - { - // TODO: should specify which output of current gnode - named_gnode.gnode_index.gnode->get_output_tensor_ptr(0)->set_name( - named_gnode.name); - m_graph_outputs.emplace_back(named_gnode.gnode_index.gnode); - } - } - } - - // Sort nGraph output nodes in the order of ONNX output nodes - std::vector output_names; - for (const auto& output : onnx_graph_proto->output()) - { - output_names.push_back(output.name()); - } - sort(m_graph_outputs.begin(), - m_graph_outputs.end(), - [&output_names](const std::shared_ptr& a, - const std::shared_ptr& b) { - return std::find(output_names.begin(), output_names.end(), a->get_name()) < - std::find(output_names.begin(), output_names.end(), b->get_name()); - }); - - m_graph->set_default_parameters(); - m_graph->set_outputs(m_graph_outputs); + m_graph = convert_graph(*onnx_graph_proto); NNFUSION_LOG(INFO) << "convert graph done"; } - NamedNodeVector GraphConvert::convert_node(const onnx::NodeProto& node_proto) + std::shared_ptr + GraphConvert::convert_graph(const onnx::GraphProto& graph_proto, + const NodeMap& node_map) { - NNFUSION_LOG(INFO) << "convert node: " << node_proto.name(); - const auto& convert_func = - get_convert_func(node_proto.op_type(), node_proto.domain()); - NamedNodeVector ret; - if (convert_func) - { - ret = convert_func(node_proto, 
m_node_map, m_graph); - } - else - { - NNFUSION_LOG(NNFUSION_WARNING) - << "No translator for " - << (node_proto.domain().empty() ? "" : node_proto.domain() + ".") - << node_proto.op_type() << ", try to convert to generic op"; - ret = TranslateCustomOp( - node_proto, m_node_map, m_graph, domain2version.at(node_proto.domain())); - } - for (int i = 0; i < ret.size(); i++) - { - NNFUSION_LOG(INFO) << "node " << node_proto.name() << ", output " << ret[i].name - << ", shape " << ret[i].gnode_index.get_shape(); - } - return std::move(ret); + GraphProtoConvert converter = GraphProtoConvert(graph_proto, + m_domain_convert_func_map, + m_model_dir, + m_domain2version, + m_dim_params, + node_map, + false); + return converter.get_graph(); } - const ConvertFunc& GraphConvert::get_convert_func(const std::string& name, - const std::string& domain) const - { - if (m_domain_convert_func_map.find(domain) == m_domain_convert_func_map.end() || - m_domain_convert_func_map.at(domain).find(name) == - m_domain_convert_func_map.at(domain).end()) - return EMPTY_CONVERT_FUNC; - return m_domain_convert_func_map.at(domain).at(name); - } + // std::unordered_map + // GraphConvert::construct_controlflow_graphproto(const onnx::GraphProto& graph_proto) + // { + // // currently, this function does not support nested controlflow + // std::unordered_map controlflow_graphproto_map; + // // std::vector unsorted_nodes(std::begin(onnx_graph_proto->node()), + // // std::end(onnx_graph_proto->node())); + // // auto tensorproto_map = extract_tensorproto(graph_proto); + // for (auto node_proto : graph_proto.node()) + // { + // if (node_proto.op_type() == "If") + // { + // Node node(node_proto); + // controlflow_graphproto_map[node_proto.name() + "_If_then_branch"] = + // complete_graphproto( + // node.get_attribute_value("then_branch")); + // controlflow_graphproto_map[node_proto.name() + "_If_else_branch"] = + // complete_graphproto( + // node.get_attribute_value("else_branch")); + // } + // else if 
(node_proto.op_type() == "Loop") + // { + // Node node(node_proto); + // controlflow_graphproto_map[node_proto.name() + "_Loop_body"] = + // complete_graphproto(node.get_attribute_value("body")); + // } + // // else if (node_proto.op_type() == "Scan") + // // { + // // // + // // } + // } + + // return controlflow_graphproto_map; + // } bool GraphConvert::is_operator_available(const onnx::NodeProto& node_proto) const { diff --git a/src/nnfusion/frontend/onnx_import/util/graph_convert.hpp b/src/nnfusion/frontend/onnx_import/util/graph_convert.hpp index 096e18a4f..5ee346678 100644 --- a/src/nnfusion/frontend/onnx_import/util/graph_convert.hpp +++ b/src/nnfusion/frontend/onnx_import/util/graph_convert.hpp @@ -38,6 +38,51 @@ namespace nnfusion { namespace onnx_import { + class GraphProtoConvert + { + public: + GraphProtoConvert( + const onnx::GraphProto& graph_proto, + const std::unordered_map& domain_convert_func_map, + const string& model_dir, + const std::unordered_map& domain2version, + const std::unordered_map& dim_params = {}, + const NodeMap& _node_map = NodeMap(), + bool flag_subgraph = false); + + std::shared_ptr get_graph() { return m_graph; } + const onnx::GraphProto& get_onnx_proto_graph() const { return *onnx_graph_proto; } + NamedNodeVector convert_node(const onnx::NodeProto& node_proto); + + /// \brief Access an operator object by its type name and domain name + /// The function will return the operator object if it exists, or report an error + /// in case of domain or operator absence. + /// \param name type name of the operator object, + /// \param domain domain name of the operator object. + /// \return Reference to the operator object. 
+ const ConvertFunc& get_convert_func(const std::string& name, + const std::string& domain) const; + + private: + const onnx::GraphProto* onnx_graph_proto; + + std::shared_ptr m_graph; + + std::unordered_map m_domain_convert_func_map; + + NodeMap m_node_map; + + // TODO: to be removed + std::set m_output_names; + + graph::GNodeVector m_graph_outputs; + + std::unordered_map m_dim_params; + std::string m_model_dir; + std::unordered_map m_domain2version; + + bool m_flag_subgraph; + }; class GraphConvert { public: @@ -66,16 +111,19 @@ namespace nnfusion return onnx_model_proto->producer_version(); } - NamedNodeVector convert_node(const onnx::NodeProto& node_proto); + /// \brief Convert ONNX::GraphProto to nnfusion graph + /// \param graph_proto ONNX GraphProto + /// \param _node_map pre-provided node_map, empty by default + /// \return std::shared_ptr + std::shared_ptr + convert_graph(const onnx::GraphProto& graph_proto, + const NodeMap& _node_map = NodeMap()); - /// \brief Access an operator object by its type name and domain name - /// The function will return the operator object if it exists, or report an error - /// in case of domain or operator absence. - /// \param name type name of the operator object, - /// \param domain domain name of the operator object. - /// \return Reference to the operator object. - const ConvertFunc& get_convert_func(const std::string& name, - const std::string& domain) const; + // /// \brief Construct complete GraphProtos for sub-graphs in control-flow nodes (e.g., If, Loop) by adding the missing information (i.e., inputs) of the GraphProto, which could be processed by GraphProtoConvert to get nnfusion graph + // /// \param graph_proto the graph_proto of the ONNX model + // /// \returns unordered_map + // std::unordered_map + // construct_controlflow_graphproto(const onnx::GraphProto& graph_proto); /// \brief Check availability of operator base on NodeProto. /// \return `true` if the operator is available, otherwise it returns `false`. 
@@ -89,16 +137,18 @@ namespace nnfusion std::unordered_map m_domain_convert_func_map; - NodeMap m_node_map; + // std::unordered_map m_controlflow_graphproto_map; + + // NodeMap m_node_map; // TODO: to be removed - std::set m_output_names; + // std::set m_output_names; - graph::GNodeVector m_graph_outputs; + // graph::GNodeVector m_graph_outputs; std::unordered_map m_dim_params; - std::string model_dir; - std::unordered_map domain2version; + std::string m_model_dir; + std::unordered_map m_domain2version; }; } // namespace onnx_import } // namespace frontend diff --git a/src/nnfusion/frontend/onnx_import/util/util.cpp b/src/nnfusion/frontend/onnx_import/util/util.cpp index d6f52653d..2fb514ff7 100644 --- a/src/nnfusion/frontend/onnx_import/util/util.cpp +++ b/src/nnfusion/frontend/onnx_import/util/util.cpp @@ -280,6 +280,67 @@ namespace nnfusion name, std::vector(kernel_shape.size(), 1UL)); } + std::unordered_set extract_input(const onnx::GraphProto& graph_proto) + { + std::unordered_set node_inputs; + std::unordered_set node_outputs; + + for (auto node_proto : graph_proto.node()) + { + for (size_t i = 0; i < node_proto.input_size(); i++) + { + node_inputs.insert(node_proto.input(i)); + } + for (size_t i = 0; i < node_proto.output_size(); i++) + { + node_outputs.insert(node_proto.output(i)); + } + } + + std::unordered_set graph_inputs; + for (auto item : node_inputs) + { + if (node_outputs.find(item) == node_outputs.end()) + { + graph_inputs.insert(item); + } + } + + return graph_inputs; + } + + onnx::GraphProto complete_graphproto(const onnx::GraphProto& graph_proto) + { + onnx::GraphProto completed_graphproto(graph_proto); + + auto all_inputs = extract_input(graph_proto); + std::unordered_set existing_inputs; + for (auto input_proto : graph_proto.input()) + { + existing_inputs.insert(input_proto.name()); + } + + std::unordered_set missing_inputs; + for (auto input : all_inputs) + { + if (existing_inputs.find(input) == existing_inputs.end()) + { + 
missing_inputs.insert(input); + // std::cout << input << std::endl; + } + } + + for (auto item : missing_inputs) + { + auto input = completed_graphproto.add_input(); + input->set_name(item); + } + + // std::cout << completed_graphproto.DebugString() << std::endl; + + return completed_graphproto; + } + } // namespace onnx_import } // namespace frontend } // namespace nnfusion diff --git a/src/nnfusion/frontend/onnx_import/util/util.hpp b/src/nnfusion/frontend/onnx_import/util/util.hpp index 871ab4801..40891233d 100644 --- a/src/nnfusion/frontend/onnx_import/util/util.hpp +++ b/src/nnfusion/frontend/onnx_import/util/util.hpp @@ -160,6 +160,54 @@ namespace nnfusion return __get_data(tensor.uint64_data()); } + template <> + inline std::vector get_data(const onnx::TensorProto& tensor) + { + // NNFUSION_CHECK(tensor.has_raw_data()) << "Data type char only supports raw_data"; + // return __get_raw_data(tensor.raw_data()); + + // the following is for control-flow test + if (tensor.has_raw_data()) + { + return __get_raw_data(tensor.raw_data()); + } + // NNFUSION_CHECK(tensor.data_type() == onnx::TensorProto_DataType_INT32) + // << "invalid data type: " + // << onnx::TensorProto_DataType_Name( + // static_cast(tensor.data_type())); + auto tmp = __get_data(tensor.int32_data()); + std::vector ret; + for (auto item : tmp) + { + ret.push_back((char)item); + } + return ret; + } + + template <> + inline std::vector get_data(const onnx::TensorProto& tensor) + { + // NNFUSION_CHECK(tensor.has_raw_data()) << "Data type boolean only supports raw_data"; + // return __get_raw_data(tensor.raw_data()); + + // the following is for control-flow test + if (tensor.has_raw_data()) + { + return __get_raw_data(tensor.raw_data()); + } + // NNFUSION_CHECK(tensor.data_type() == onnx::TensorProto_DataType_INT32) + // << "invalid data type: " + // << onnx::TensorProto_DataType_Name( + // static_cast(tensor.data_type())); + auto tmp = __get_data(tensor.int32_data()); + std::vector ret; + for (auto 
item : tmp) + { + ret.push_back((bool)item); + } + return ret; + } + /// \brief Fill specified range with monotonic sequence. /// /// \param[in] first The iterator to the beginning of the range. @@ -181,7 +229,7 @@ namespace nnfusion *first = init_value; } } - } + } // namespace detail class Tensor; class Node; @@ -322,6 +370,9 @@ namespace nnfusion const std::string& name, const Shape& kernel_shape); + std::unordered_set extract_input(const onnx::GraphProto& graph_proto); + onnx::GraphProto complete_graphproto(const onnx::GraphProto& graph_proto); + } // namespace onnx_import } // namespace frontend } // namespace nnfusion