Skip to content

Commit

Permalink
Initial Commit
Browse files Browse the repository at this point in the history
Signed-off-by: Evgeniia Nugmanova <[email protected]>
  • Loading branch information
jane-intel committed Jan 3, 2025
1 parent 4ebb6ed commit e4df58f
Show file tree
Hide file tree
Showing 16 changed files with 389 additions and 94 deletions.
42 changes: 21 additions & 21 deletions src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ static void reg_pattern_wrap_type(py::module m) {

wrap_type.def(py::init([](const std::string& type_name, const ov::Output<ov::Node>& input) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
nullptr,
ov::pass::pattern::op::Predicate(),
ov::OutputVector{input});
}),
py::arg("type_name"),
Expand All @@ -106,7 +106,7 @@ static void reg_pattern_wrap_type(py::module m) {

wrap_type.def(py::init([](const std::string& type_name, const std::shared_ptr<ov::Node>& input) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
nullptr,
ov::pass::pattern::op::Predicate(),
ov::OutputVector{input});
}),
py::arg("type_name"),
Expand Down Expand Up @@ -165,7 +165,9 @@ static void reg_pattern_wrap_type(py::module m) {
)");

wrap_type.def(py::init([](const std::string& type_name, const ov::OutputVector& inputs) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name), nullptr, inputs);
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
ov::pass::pattern::op::Predicate(),
inputs);
}),
py::arg("type_name"),
py::arg("inputs"),
Expand All @@ -181,7 +183,7 @@ static void reg_pattern_wrap_type(py::module m) {

wrap_type.def(py::init([](const std::string& type_name, const ov::NodeVector& inputs) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
nullptr,
ov::pass::pattern::op::Predicate(),
ov::as_output_vector(inputs));
}),
py::arg("type_name"),
Expand Down Expand Up @@ -264,7 +266,7 @@ static void reg_pattern_wrap_type(py::module m) {

wrap_type.def(py::init([](const std::vector<std::string>& type_names, const ov::Output<ov::Node>& input) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
nullptr,
ov::pass::pattern::op::Predicate(),
ov::OutputVector{input});
}),
py::arg("type_names"),
Expand All @@ -281,7 +283,7 @@ static void reg_pattern_wrap_type(py::module m) {

wrap_type.def(py::init([](const std::vector<std::string>& type_names, const std::shared_ptr<ov::Node>& input) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
nullptr,
ov::pass::pattern::op::Predicate(),
ov::OutputVector{input});
}),
py::arg("type_names"),
Expand Down Expand Up @@ -343,7 +345,9 @@ static void reg_pattern_wrap_type(py::module m) {
)");

wrap_type.def(py::init([](const std::vector<std::string>& type_names, const ov::OutputVector& inputs) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names), nullptr, inputs);
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
ov::pass::pattern::op::Predicate(),
inputs);
}),
py::arg("type_names"),
py::arg("inputs"),
Expand All @@ -359,7 +363,7 @@ static void reg_pattern_wrap_type(py::module m) {

wrap_type.def(py::init([](const std::vector<std::string>& type_names, const ov::NodeVector& inputs) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
nullptr,
ov::pass::pattern::op::Predicate(),
ov::as_output_vector(inputs));
}),
py::arg("type_names"),
Expand Down Expand Up @@ -501,8 +505,7 @@ static void reg_pattern_optional(py::module m) {

optional_type.def(py::init([](const std::vector<std::string>& type_names, const ov::Output<ov::Node>& input) {
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names),
ov::OutputVector{input},
nullptr);
ov::OutputVector{input});
}),
py::arg("type_names"),
py::arg("input"),
Expand All @@ -518,8 +521,7 @@ static void reg_pattern_optional(py::module m) {

optional_type.def(py::init([](const std::vector<std::string>& type_names, const std::shared_ptr<ov::Node>& input) {
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names),
ov::OutputVector{input},
nullptr);
ov::OutputVector{input});
}),
py::arg("type_names"),
py::arg("input"),
Expand All @@ -533,13 +535,12 @@ static void reg_pattern_optional(py::module m) {
:type input: openvino.runtime.Node
)");

optional_type.def(
py::init([](const std::vector<std::string>& type_names, const ov::OutputVector& inputs) {
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names), inputs, nullptr);
}),
py::arg("type_names"),
py::arg("inputs"),
R"(
optional_type.def(py::init([](const std::vector<std::string>& type_names, const ov::OutputVector& inputs) {
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names), inputs);
}),
py::arg("type_names"),
py::arg("inputs"),
R"(
Create Optional with the given node type and input node.
:param type_names: node type. For example: ["opset8.Abs", "opset8.Relu"]
Expand All @@ -551,8 +552,7 @@ static void reg_pattern_optional(py::module m) {

optional_type.def(py::init([](const std::vector<std::string>& type_names, const ov::NodeVector& inputs) {
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names),
ov::as_output_vector(inputs),
nullptr);
ov::as_output_vector(inputs));
}),
py::arg("type_names"),
py::arg("inputs"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations/symbolic_transformations/symbolic_optimizations.hpp"
#include "transformations_visibility.hpp"

namespace ov {
Expand Down Expand Up @@ -88,29 +89,42 @@ class ov::pass::RoPEShareCosSin : public ov::pass::MatcherPass {
* @ingroup ov_transformation_common_api
* @brief Fuses special sub-graph into an internal Rotary Positional Embedding operation
*/
class ov::pass::RoPEFusion : public ov::pass::GraphRewrite {

class ov::pass::RoPEFusion : public ov::pass::ModelPass {
public:
OPENVINO_GRAPH_REWRITE_RTTI("RoPEFusion");
RoPEFusion(bool support_2d_rope = false) {
add_matcher<ov::pass::RoPEFusionFlux>();
add_matcher<ov::pass::RoPEFusionGPTNEOX>();
add_matcher<ov::pass::RoPEFusionGPTJ>();
OPENVINO_MODEL_PASS_RTTI("RoPEFusion");

explicit RoPEFusion(bool support_2d_rope = false) : support_2d_rope(support_2d_rope){};

bool run_on_model(const std::shared_ptr<ov::Model>& m) override {
auto symbolic_pipeline = ov::pass::SymbolicOptimizations(false);
auto rope_fusions = symbolic_pipeline.get_manager()->register_pass<ov::pass::GraphRewrite>();
rope_fusions->set_name("RoPEFusions");

rope_fusions->add_matcher<ov::pass::RoPEFusionFlux>();
rope_fusions->add_matcher<ov::pass::RoPEFusionGPTNEOX>();
rope_fusions->add_matcher<ov::pass::RoPEFusionGPTJ>();
// optional heads & tails are fused in separate matcher pass,
// after RoPENode has been created.
add_matcher<ov::pass::RoPEFusionCosSinPreprocess>();
add_matcher<ov::pass::RoPEFusionIOSlicing>();
add_matcher<ov::pass::RoPEFusionPreprocess>();
rope_fusions->add_matcher<ov::pass::RoPEFusionCosSinPreprocess>();
rope_fusions->add_matcher<ov::pass::RoPEFusionIOSlicing>();
rope_fusions->add_matcher<ov::pass::RoPEFusionPreprocess>();

add_matcher<ov::pass::RoPEFusionChatGLM>(0);
add_matcher<ov::pass::RoPEFusionChatGLM>(1);
rope_fusions->add_matcher<ov::pass::RoPEFusionChatGLM>(0);
rope_fusions->add_matcher<ov::pass::RoPEFusionChatGLM>(1);
if (support_2d_rope) {
add_matcher<ov::pass::RoPEFusionChatGLM>(0, true);
add_matcher<ov::pass::RoPEFusionChatGLM>(1, true);
rope_fusions->add_matcher<ov::pass::RoPEFusionChatGLM>(0, true);
rope_fusions->add_matcher<ov::pass::RoPEFusionChatGLM>(1, true);
}

add_matcher<ov::pass::RoPEFusionQwen>(0);
add_matcher<ov::pass::RoPEFusionQwen>(1);
rope_fusions->add_matcher<ov::pass::RoPEFusionQwen>(0);
rope_fusions->add_matcher<ov::pass::RoPEFusionQwen>(1);

rope_fusions->add_matcher<ov::pass::RoPEShareCosSin>();

add_matcher<ov::pass::RoPEShareCosSin>();
return symbolic_pipeline.run_on_model(m);
}

protected:
bool support_2d_rope = false;
};
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ ov::pass::LabelResolvingThroughSelect::LabelResolvingThroughSelect() {
}

ov::pass::SymbolicOptimizations::SymbolicOptimizations(bool full_run) {
m_manager = std::make_shared<pass::Manager>("Symbolic");
m_manager = std::make_shared<pass::Manager>(get_pass_config(), "Symbolic");
m_manager->set_per_pass_validation(false);

#define REGISTER_SYMBOLIC(region, ...) m_manager->register_pass<region>(__VA_ARGS__);
Expand Down Expand Up @@ -208,7 +208,7 @@ bool ov::pass::SymbolicOptimizations::run_on_model(const std::shared_ptr<ov::Mod
pass_config->disable<EliminateSqueeze>();
pass_config->disable<EliminateUnsqueeze>();

m_manager->run_passes(m);
bool status = m_manager->run_passes(m);
ov::remove_skip_invalidation_rti(m);
return true;
return status;
}
4 changes: 4 additions & 0 deletions src/core/include/openvino/pass/pattern/matcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ class OPENVINO_API Matcher {
PatternValueMap& get_pattern_value_map() {
return m_pattern_map;
}
PatternSymbolMap& get_symbols() {
return m_pattern_symbols;
}
PatternValueMaps& get_pattern_value_maps() {
return m_pattern_value_maps;
}
Expand Down Expand Up @@ -198,6 +201,7 @@ class OPENVINO_API Matcher {
Output<Node> m_match_root;
Output<Node> m_pattern_node;
PatternValueMap m_pattern_map;
PatternSymbolMap m_pattern_symbols;
PatternValueMaps m_pattern_value_maps;
OutputVector m_matched_list;

Expand Down
20 changes: 13 additions & 7 deletions src/core/include/openvino/pass/pattern/op/any.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,27 @@ class OPENVINO_API Any : public Pattern {
OPENVINO_RTTI("patternAny");
/// \brief creates a Any node containing a sub-pattern described by \sa type and \sa
/// shape.
Any(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values)
Any(const element::Type& type, const PartialShape& s, Predicate pred, const OutputVector& wrapped_values)
: Pattern(wrapped_values, pred) {
set_output_type(0, type, s);
}
Any(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values)
: Any(type, s, Predicate(pred), wrapped_values) {}
Any(const element::Type& type, const PartialShape& s, SymbolPredicate pred, const OutputVector& wrapped_values)
: Any(type, s, Predicate(pred), wrapped_values) {}
Any(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values)
: Any(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
: Any(type, s, pred, as_output_vector(wrapped_values)) {}

/// \brief creates a Any node containing a sub-pattern described by the type and
/// shape of \sa node.
Any(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
Any(const Output<Node>& node, Predicate pred, const OutputVector& wrapped_values)
: Any(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
Any(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
: Any(node, Predicate(pred), wrapped_values) {}
Any(const Output<Node>& node, SymbolPredicate pred, const OutputVector& wrapped_values)
: Any(node, Predicate(pred), wrapped_values) {}
Any(const Output<Node>& node, NodePredicate pred, const NodeVector& wrapped_values)
: Any(node.get_element_type(),
node.get_partial_shape(),
as_value_predicate(pred),
as_output_vector(wrapped_values)) {}
: Any(node, pred, as_output_vector(wrapped_values)) {}

bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
Expand Down
9 changes: 9 additions & 0 deletions src/core/include/openvino/pass/pattern/op/any_of.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ class OPENVINO_API AnyOf : public Pattern {
}
set_output_type(0, type, s);
}
AnyOf(const element::Type& type, const PartialShape& s, SymbolPredicate pred, const OutputVector& wrapped_values)
: Pattern(wrapped_values, pred) {
if (wrapped_values.size() != 1) {
OPENVINO_THROW("AnyOf expects exactly one argument");
}
set_output_type(0, type, s);
}
AnyOf(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values)
: AnyOf(
type,
Expand All @@ -43,6 +50,8 @@ class OPENVINO_API AnyOf : public Pattern {
/// shape of \sa node.
AnyOf(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
: AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
AnyOf(const Output<Node>& node, SymbolPredicate pred, const OutputVector& wrapped_values)
: AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
AnyOf(const std::shared_ptr<Node>& node, NodePredicate pred, const NodeVector& wrapped_values)
: AnyOf(node, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
Expand Down
17 changes: 16 additions & 1 deletion src/core/include/openvino/pass/pattern/op/label.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,13 @@ class OPENVINO_API Label : public Pattern {
: Pattern(OutputVector{wrap_values(wrapped_values)}, pred) {
set_output_type(0, type, s);
}

Label(const element::Type& type,
const PartialShape& s,
const SymbolPredicate pred,
const OutputVector& wrapped_values)
: Pattern(OutputVector{wrap_values(wrapped_values)}, pred) {
set_output_type(0, type, s);
}
explicit Label(const element::Type& type = element::dynamic, const PartialShape& s = PartialShape::dynamic())
: Label(
type,
Expand All @@ -56,6 +62,9 @@ class OPENVINO_API Label : public Pattern {
Label(const element::Type& type, const PartialShape& s, ValuePredicate pred)
: Label(type, s, std::move(pred), OutputVector{}) {}

Label(const element::Type& type, const PartialShape& s, SymbolPredicate pred)
: Label(type, s, std::move(pred), OutputVector{}) {}

Label(const element::Type& type, const PartialShape& s, NodePredicate pred)
: Label(type, s, as_value_predicate(std::move(pred)), OutputVector{}) {}

Expand All @@ -78,6 +87,10 @@ class OPENVINO_API Label : public Pattern {
: Label(value.get_element_type(), value.get_partial_shape(), pred, wrapped_values) {}
Label(const Output<Node>& value, const ValuePredicate pred)
: Label(value.get_element_type(), value.get_partial_shape(), pred, OutputVector{}) {}
Label(const Output<Node>& value, const SymbolPredicate pred, const OutputVector& wrapped_values)
: Label(value.get_element_type(), value.get_partial_shape(), pred, wrapped_values) {}
Label(const Output<Node>& value, const SymbolPredicate pred)
: Label(value.get_element_type(), value.get_partial_shape(), pred, OutputVector{}) {}

Label(const Output<Node>& value, const NodePredicate pred)
: Label(value.get_element_type(), value.get_partial_shape(), as_value_predicate(pred), OutputVector{}) {}
Expand Down Expand Up @@ -107,6 +120,8 @@ std::shared_ptr<Node> any_input();

OPENVINO_API
std::shared_ptr<Node> any_input(const pattern::op::ValuePredicate& pred);
OPENVINO_API
std::shared_ptr<Node> any_input(const pattern::op::SymbolPredicate& pred);
} // namespace pattern
} // namespace pass
} // namespace ov
Loading

0 comments on commit e4df58f

Please sign in to comment.