diff --git a/src/nnfusion/engine/device/cpu.cpp b/src/nnfusion/engine/device/cpu.cpp index 05f10c346..a7562269e 100644 --- a/src/nnfusion/engine/device/cpu.cpp +++ b/src/nnfusion/engine/device/cpu.cpp @@ -22,6 +22,7 @@ #include "nnfusion/engine/pass/graph/multi_reshape_folding_pass.hpp" #include "nnfusion/engine/pass/graph/op_inplace_pass.hpp" #include "nnfusion/engine/pass/graph/pattern_substitution.hpp" +#include "nnfusion/engine/pass/graph/purify_graph_pass.hpp" #include "nnfusion/engine/pass/graph/runtime_const_folding_pass.hpp" #include "nnfusion/engine/pass/graph/vector_dot_transpose_pass.hpp" @@ -46,6 +47,7 @@ CpuEngine::CpuEngine() g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); + g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); diff --git a/src/nnfusion/engine/device/cuda.cpp b/src/nnfusion/engine/device/cuda.cpp index 7bb4dfee6..4d34bb9b7 100644 --- a/src/nnfusion/engine/device/cuda.cpp +++ b/src/nnfusion/engine/device/cuda.cpp @@ -23,6 +23,7 @@ #include "nnfusion/engine/pass/graph/multi_reshape_folding_pass.hpp" #include "nnfusion/engine/pass/graph/op_inplace_pass.hpp" #include "nnfusion/engine/pass/graph/pattern_substitution.hpp" +#include "nnfusion/engine/pass/graph/purify_graph_pass.hpp" #include "nnfusion/engine/pass/graph/reduce_fusion_pass.hpp" #include "nnfusion/engine/pass/graph/runtime_const_folding_pass.hpp" #include "nnfusion/engine/pass/graph/superscaler_dataparallelism_pass.hpp" @@ -49,6 +50,7 @@ CudaEngine::CudaEngine() g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); + g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); diff --git a/src/nnfusion/engine/device/hlsl.cpp b/src/nnfusion/engine/device/hlsl.cpp index 740dcb6a8..250963f55 100644 
--- a/src/nnfusion/engine/device/hlsl.cpp +++ b/src/nnfusion/engine/device/hlsl.cpp @@ -25,6 +25,7 @@ #include "nnfusion/engine/pass/graph/multi_reshape_folding_pass.hpp" #include "nnfusion/engine/pass/graph/op_inplace_pass.hpp" #include "nnfusion/engine/pass/graph/pattern_substitution.hpp" +#include "nnfusion/engine/pass/graph/purify_graph_pass.hpp" #include "nnfusion/engine/pass/graph/reduce_fusion_pass.hpp" #include "nnfusion/engine/pass/graph/runtime_const_folding_pass.hpp" #include "nnfusion/engine/pass/graph/vector_dot_transpose_pass.hpp" @@ -52,6 +53,7 @@ HLSLEngine::HLSLEngine() g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); + g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); @@ -91,6 +93,7 @@ HLSLEngine::HLSLEngine() g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); + g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); @@ -128,6 +131,7 @@ HLSLEngine::HLSLEngine() { g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); + g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); // Kernel selection diff --git a/src/nnfusion/engine/device/rocm.cpp b/src/nnfusion/engine/device/rocm.cpp index 1ef39add2..6d17c65dc 100644 --- a/src/nnfusion/engine/device/rocm.cpp +++ b/src/nnfusion/engine/device/rocm.cpp @@ -22,6 +22,7 @@ #include "nnfusion/engine/pass/graph/multi_reshape_folding_pass.hpp" #include "nnfusion/engine/pass/graph/op_inplace_pass.hpp" #include "nnfusion/engine/pass/graph/pattern_substitution.hpp" +#include "nnfusion/engine/pass/graph/purify_graph_pass.hpp" #include "nnfusion/engine/pass/graph/reduce_fusion_pass.hpp" #include "nnfusion/engine/pass/graph/runtime_const_folding_pass.hpp" #include 
"nnfusion/engine/pass/graph/vector_dot_transpose_pass.hpp" @@ -47,6 +48,7 @@ ROCmEngine::ROCmEngine() g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); + g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); diff --git a/src/nnfusion/engine/pass/graph/CMakeLists.txt b/src/nnfusion/engine/pass/graph/CMakeLists.txt index 6f4b2ac04..fc934aeba 100644 --- a/src/nnfusion/engine/pass/graph/CMakeLists.txt +++ b/src/nnfusion/engine/pass/graph/CMakeLists.txt @@ -28,6 +28,7 @@ set(SRC dot_transpose_pass.cpp reduce_fusion_pass.cpp superscaler_dataparallelism_pass.cpp + purify_graph_pass.cpp ) add_library(nnfusion_engine_pass_graph STATIC ${SRC}) diff --git a/src/nnfusion/engine/pass/graph/purify_graph_pass.cpp b/src/nnfusion/engine/pass/graph/purify_graph_pass.cpp new file mode 100644 index 000000000..3951c94b9 --- /dev/null +++ b/src/nnfusion/engine/pass/graph/purify_graph_pass.cpp @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+
+#include // NOTE(review): include target is missing in this copy of the patch (angle-bracket text looks stripped) - confirm against the real file
+#include // NOTE(review): same as above - include target missing here
+
+#include "nnfusion/core/operators/generic_op/generic_op.hpp"
+#include "nnfusion/core/operators/op_define/fused.hpp"
+#include "nnfusion/core/operators/op_define/noop.hpp"
+#include "purify_graph_pass.hpp"
+
+using namespace nnfusion::graph;
+using namespace nnfusion::pass::graph;
+using namespace nnfusion::kernels; // NOTE(review): nothing visible in this file uses nnfusion::kernels - confirm it is needed
+
+bool PurifyGraphPass::run_on_graph(std::shared_ptr& graph) // Removes from `graph` every node that graph->get_ordered_ops() does not return; always returns true. NOTE(review): the shared_ptr template argument is missing in this copy of the patch.
+{
+    NNFUSION_LOG(INFO) << "Purify Graph Pass started:";
+    std::unordered_set> valided_nodes; // keep-set: filled below with every node returned by get_ordered_ops()
+    for (auto node : graph->get_ordered_ops())
+    {
+        valided_nodes.insert(node);
+    }
+    auto all_nodes = graph->get_nodes(); // held in a local so the removal loop iterates this object rather than calling get_nodes() again while remove_node() runs
+    NNFUSION_LOG(INFO) << "Before: " + to_string(all_nodes.size()); // node count before any removal
+    for (auto node : all_nodes)
+    {
+        if (valided_nodes.find(node) != valided_nodes.end())
+            continue; // node is in the keep-set: leave it in the graph
+        graph->remove_node(node); // node was not returned by get_ordered_ops(): drop it
+    }
+    NNFUSION_LOG(INFO) << "After: " + to_string(graph->get_nodes().size()); // node count re-read from the graph after removal
+    NNFUSION_LOG(INFO) << "Purify Graph Pass finished";
+    return true; // the only return statement; this function has no failure path
+}
\ No newline at end of file
diff --git a/src/nnfusion/engine/pass/graph/purify_graph_pass.hpp b/src/nnfusion/engine/pass/graph/purify_graph_pass.hpp
new file mode 100644
index 000000000..52fb039a2
--- /dev/null
+++ b/src/nnfusion/engine/pass/graph/purify_graph_pass.hpp
@@ -0,0 +1,25 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "graph_pass_base.hpp"
+#include "nnfusion/common/common.hpp"
+#include "nnfusion/engine/cache/manager.hpp" // NOTE(review): nothing declared in this header visibly uses the cache manager - confirm it is needed
+#include "nnfusion/engine/op.hpp"
+#include "nnfusion/engine/profiler/profiler.hpp" // NOTE(review): nothing declared in this header visibly uses the profiler - confirm it is needed
+
+namespace nnfusion
+{
+    namespace pass
+    {
+        namespace graph
+        {
+            class PurifyGraphPass : public GraphPassBase // Graph pass that declares a single override, run_on_graph(); its definition is in purify_graph_pass.cpp
+            {
+            public:
+                bool run_on_graph(std::shared_ptr& graph) override; // Takes the graph by non-const reference to a shared_ptr; the definition removes nodes through it and returns true
+            };
+        } // namespace graph
+    } // namespace pass
+} // namespace nnfusion
\ No newline at end of file