diff --git a/src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp b/src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp index bc1a64e13364c4..6708eaef7d3943 100644 --- a/src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp +++ b/src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp @@ -89,7 +89,7 @@ static void reg_pattern_wrap_type(py::module m) { wrap_type.def(py::init([](const std::string& type_name, const ov::Output& input) { return std::make_shared(get_type(type_name), - nullptr, + ov::pass::pattern::op::Predicate(), ov::OutputVector{input}); }), py::arg("type_name"), @@ -106,7 +106,7 @@ static void reg_pattern_wrap_type(py::module m) { wrap_type.def(py::init([](const std::string& type_name, const std::shared_ptr& input) { return std::make_shared(get_type(type_name), - nullptr, + ov::pass::pattern::op::Predicate(), ov::OutputVector{input}); }), py::arg("type_name"), @@ -165,7 +165,9 @@ static void reg_pattern_wrap_type(py::module m) { )"); wrap_type.def(py::init([](const std::string& type_name, const ov::OutputVector& inputs) { - return std::make_shared(get_type(type_name), nullptr, inputs); + return std::make_shared(get_type(type_name), + ov::pass::pattern::op::Predicate(), + inputs); }), py::arg("type_name"), py::arg("inputs"), @@ -181,7 +183,7 @@ static void reg_pattern_wrap_type(py::module m) { wrap_type.def(py::init([](const std::string& type_name, const ov::NodeVector& inputs) { return std::make_shared(get_type(type_name), - nullptr, + ov::pass::pattern::op::Predicate(), ov::as_output_vector(inputs)); }), py::arg("type_name"), @@ -264,7 +266,7 @@ static void reg_pattern_wrap_type(py::module m) { wrap_type.def(py::init([](const std::vector& type_names, const ov::Output& input) { return std::make_shared(get_types(type_names), - nullptr, + ov::pass::pattern::op::Predicate(), ov::OutputVector{input}); }), py::arg("type_names"), @@ -281,7 +283,7 @@ static void reg_pattern_wrap_type(py::module m) { wrap_type.def(py::init([](const std::vector& type_names, const std::shared_ptr& input) { return std::make_shared(get_types(type_names), - nullptr, + ov::pass::pattern::op::Predicate(), ov::OutputVector{input}); }), py::arg("type_names"), @@ -343,7 +345,9 @@ static void reg_pattern_wrap_type(py::module m) { )"); wrap_type.def(py::init([](const std::vector& type_names, const ov::OutputVector& inputs) { - return std::make_shared(get_types(type_names), nullptr, inputs); + return std::make_shared(get_types(type_names), + ov::pass::pattern::op::Predicate(), + inputs); }), py::arg("type_names"), py::arg("inputs"), @@ -359,7 +363,7 @@ static void reg_pattern_wrap_type(py::module m) { wrap_type.def(py::init([](const std::vector& type_names, const ov::NodeVector& inputs) { return std::make_shared(get_types(type_names), - nullptr, + ov::pass::pattern::op::Predicate(), ov::as_output_vector(inputs)); }), py::arg("type_names"), @@ -501,8 +505,7 @@ static void reg_pattern_optional(py::module m) { optional_type.def(py::init([](const std::vector& type_names, const ov::Output& input) { return std::make_shared(get_types(type_names), - ov::OutputVector{input}, - nullptr); + ov::OutputVector{input}); }), py::arg("type_names"), py::arg("input"), @@ -518,8 +521,7 @@ static void reg_pattern_optional(py::module m) { optional_type.def(py::init([](const std::vector& type_names, const std::shared_ptr& input) { return std::make_shared(get_types(type_names), - ov::OutputVector{input}, - nullptr); + ov::OutputVector{input}); }), py::arg("type_names"), py::arg("input"), @@ -533,13 +535,12 @@ static void reg_pattern_optional(py::module m) { :type input: openvino.runtime.Node )"); - optional_type.def( - py::init([](const std::vector& type_names, const ov::OutputVector& inputs) { - return std::make_shared(get_types(type_names), inputs, nullptr); - }), - py::arg("type_names"), - py::arg("inputs"), - R"( + optional_type.def(py::init([](const std::vector& type_names, const ov::OutputVector& inputs) { + return std::make_shared(get_types(type_names), inputs); + }), + py::arg("type_names"), + py::arg("inputs"), + R"( Create Optional with the given node type and input node. :param type_names: node type. For example: ["opset8.Abs", "opset8.Relu"] @@ -551,8 +552,7 @@ static void reg_pattern_optional(py::module m) { optional_type.def(py::init([](const std::vector& type_names, const ov::NodeVector& inputs) { return std::make_shared(get_types(type_names), - ov::as_output_vector(inputs), - nullptr); + ov::as_output_vector(inputs)); }), py::arg("type_names"), py::arg("inputs"), diff --git a/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp b/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp index 51177738c1e2d5..d0bc68c7e054cf 100644 --- a/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp @@ -5,6 +5,7 @@ #pragma once #include "openvino/pass/graph_rewrite.hpp" +#include "transformations/symbolic_transformations/symbolic_optimizations.hpp" #include "transformations_visibility.hpp" namespace ov { @@ -88,29 +89,42 @@ class ov::pass::RoPEShareCosSin : public ov::pass::MatcherPass { * @ingroup ov_transformation_common_api * @brief Fuses special sub-graph into an internal Rotary Positional Embedding operation */ -class ov::pass::RoPEFusion : public ov::pass::GraphRewrite { + +class ov::pass::RoPEFusion : public ov::pass::ModelPass { public: - OPENVINO_GRAPH_REWRITE_RTTI("RoPEFusion"); - RoPEFusion(bool support_2d_rope = false) { - add_matcher(); - add_matcher(); - add_matcher(); + OPENVINO_MODEL_PASS_RTTI("RoPEFusion"); + + explicit RoPEFusion(bool support_2d_rope = false) : support_2d_rope(support_2d_rope){}; + + bool run_on_model(const std::shared_ptr& m) override { + auto symbolic_pipeline = ov::pass::SymbolicOptimizations(false); + auto rope_fusions = symbolic_pipeline.get_manager()->register_pass(); + rope_fusions->set_name("RoPEFusions"); + + rope_fusions->add_matcher(); + rope_fusions->add_matcher(); + rope_fusions->add_matcher(); // optional heads & tails are fused in separate matcher pass, // after RoPENode has been created. - add_matcher(); - add_matcher(); - add_matcher(); + rope_fusions->add_matcher(); + rope_fusions->add_matcher(); + rope_fusions->add_matcher(); - add_matcher(0); - add_matcher(1); + rope_fusions->add_matcher(0); + rope_fusions->add_matcher(1); if (support_2d_rope) { - add_matcher(0, true); - add_matcher(1, true); + rope_fusions->add_matcher(0, true); + rope_fusions->add_matcher(1, true); } - add_matcher(0); - add_matcher(1); + rope_fusions->add_matcher(0); + rope_fusions->add_matcher(1); + + rope_fusions->add_matcher(); - add_matcher(); + return symbolic_pipeline.run_on_model(m); } + +protected: + bool support_2d_rope = false; }; diff --git a/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp b/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp index d6629f326a2a70..7c9a8a93549232 100644 --- a/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp +++ b/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp @@ -172,7 +172,7 @@ ov::pass::LabelResolvingThroughSelect::LabelResolvingThroughSelect() { } ov::pass::SymbolicOptimizations::SymbolicOptimizations(bool full_run) { - m_manager = std::make_shared("Symbolic"); + m_manager = std::make_shared(get_pass_config(), "Symbolic"); m_manager->set_per_pass_validation(false); #define REGISTER_SYMBOLIC(region, ...) m_manager->register_pass(__VA_ARGS__); @@ -208,7 +208,7 @@ bool ov::pass::SymbolicOptimizations::run_on_model(const std::shared_ptrdisable(); pass_config->disable(); - m_manager->run_passes(m); + bool status = m_manager->run_passes(m); ov::remove_skip_invalidation_rti(m); - return true; + return status; } diff --git a/src/core/include/openvino/pass/pattern/matcher.hpp b/src/core/include/openvino/pass/pattern/matcher.hpp index 7112ac9ff85e64..13249f471d3601 100644 --- a/src/core/include/openvino/pass/pattern/matcher.hpp +++ b/src/core/include/openvino/pass/pattern/matcher.hpp @@ -164,6 +164,9 @@ class OPENVINO_API Matcher { PatternValueMap& get_pattern_value_map() { return m_pattern_map; } + PatternSymbolMap& get_symbols() { + return m_pattern_symbols; + } PatternValueMaps& get_pattern_value_maps() { return m_pattern_value_maps; } @@ -198,6 +201,7 @@ class OPENVINO_API Matcher { Output m_match_root; Output m_pattern_node; PatternValueMap m_pattern_map; + PatternSymbolMap m_pattern_symbols; PatternValueMaps m_pattern_value_maps; OutputVector m_matched_list; diff --git a/src/core/include/openvino/pass/pattern/op/any.hpp b/src/core/include/openvino/pass/pattern/op/any.hpp index 65c2df8ecd87bc..f98cb28957ecfe 100644 --- a/src/core/include/openvino/pass/pattern/op/any.hpp +++ b/src/core/include/openvino/pass/pattern/op/any.hpp @@ -18,21 +18,27 @@ class OPENVINO_API Any : public Pattern { OPENVINO_RTTI("patternAny"); /// \brief creates a Any node containing a sub-pattern described by \sa type and \sa /// shape. - Any(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values) + Any(const element::Type& type, const PartialShape& s, Predicate pred, const OutputVector& wrapped_values) : Pattern(wrapped_values, pred) { set_output_type(0, type, s); } + Any(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values) + : Any(type, s, Predicate(pred), wrapped_values) {} + Any(const element::Type& type, const PartialShape& s, SymbolPredicate pred, const OutputVector& wrapped_values) + : Any(type, s, Predicate(pred), wrapped_values) {} Any(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values) - : Any(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {} + : Any(type, s, pred, as_output_vector(wrapped_values)) {} + /// \brief creates a Any node containing a sub-pattern described by the type and /// shape of \sa node. - Any(const Output& node, ValuePredicate pred, const OutputVector& wrapped_values) + Any(const Output& node, Predicate pred, const OutputVector& wrapped_values) : Any(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {} + Any(const Output& node, ValuePredicate pred, const OutputVector& wrapped_values) + : Any(node, Predicate(pred), wrapped_values) {} + Any(const Output& node, SymbolPredicate pred, const OutputVector& wrapped_values) + : Any(node, Predicate(pred), wrapped_values) {} Any(const Output& node, NodePredicate pred, const NodeVector& wrapped_values) - : Any(node.get_element_type(), - node.get_partial_shape(), - as_value_predicate(pred), - as_output_vector(wrapped_values)) {} + : Any(node, pred, as_output_vector(wrapped_values)) {} bool match_value(pattern::Matcher* matcher, const Output& pattern_value, diff --git a/src/core/include/openvino/pass/pattern/op/any_of.hpp b/src/core/include/openvino/pass/pattern/op/any_of.hpp index 7e011dc777c730..396e30eb241024 100644 --- a/src/core/include/openvino/pass/pattern/op/any_of.hpp +++ b/src/core/include/openvino/pass/pattern/op/any_of.hpp @@ -30,6 +30,13 @@ class OPENVINO_API AnyOf : public Pattern { } set_output_type(0, type, s); } + AnyOf(const element::Type& type, const PartialShape& s, SymbolPredicate pred, const OutputVector& wrapped_values) + : Pattern(wrapped_values, pred) { + if (wrapped_values.size() != 1) { + OPENVINO_THROW("AnyOf expects exactly one argument"); + } + set_output_type(0, type, s); + } AnyOf(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values) : AnyOf( type, @@ -43,6 +50,8 @@ class OPENVINO_API AnyOf : public Pattern { /// shape of \sa node. AnyOf(const Output& node, ValuePredicate pred, const OutputVector& wrapped_values) : AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {} + AnyOf(const Output& node, SymbolPredicate pred, const OutputVector& wrapped_values) + : AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {} AnyOf(const std::shared_ptr& node, NodePredicate pred, const NodeVector& wrapped_values) : AnyOf(node, as_value_predicate(pred), as_output_vector(wrapped_values)) {} bool match_value(Matcher* matcher, const Output& pattern_value, const Output& graph_value) override; diff --git a/src/core/include/openvino/pass/pattern/op/label.hpp b/src/core/include/openvino/pass/pattern/op/label.hpp index bbfa626abc0a76..257d7295ae0c7f 100644 --- a/src/core/include/openvino/pass/pattern/op/label.hpp +++ b/src/core/include/openvino/pass/pattern/op/label.hpp @@ -43,7 +43,13 @@ class OPENVINO_API Label : public Pattern { : Pattern(OutputVector{wrap_values(wrapped_values)}, pred) { set_output_type(0, type, s); } - + Label(const element::Type& type, + const PartialShape& s, + const SymbolPredicate pred, + const OutputVector& wrapped_values) + : Pattern(OutputVector{wrap_values(wrapped_values)}, pred) { + set_output_type(0, type, s); + } explicit Label(const element::Type& type = element::dynamic, const PartialShape& s = PartialShape::dynamic()) : Label( type, @@ -56,6 +62,9 @@ class OPENVINO_API Label : public Pattern { Label(const element::Type& type, const PartialShape& s, ValuePredicate pred) : Label(type, s, std::move(pred), OutputVector{}) {} + Label(const element::Type& type, const PartialShape& s, SymbolPredicate pred) + : Label(type, s, std::move(pred), OutputVector{}) {} + Label(const element::Type& type, const PartialShape& s, NodePredicate pred) : Label(type, s, as_value_predicate(std::move(pred)), OutputVector{}) {} @@ -78,6 +87,10 @@ class OPENVINO_API Label : public Pattern { : Label(value.get_element_type(), value.get_partial_shape(), pred, wrapped_values) {} Label(const Output& value, const ValuePredicate pred) : Label(value.get_element_type(), value.get_partial_shape(), pred, OutputVector{}) {} + Label(const Output& value, const SymbolPredicate pred, const OutputVector& wrapped_values) + : Label(value.get_element_type(), value.get_partial_shape(), pred, wrapped_values) {} + Label(const Output& value, const SymbolPredicate pred) + : Label(value.get_element_type(), value.get_partial_shape(), pred, OutputVector{}) {} Label(const Output& value, const NodePredicate pred) : Label(value.get_element_type(), value.get_partial_shape(), as_value_predicate(pred), OutputVector{}) {} @@ -107,6 +120,8 @@ std::shared_ptr any_input(); OPENVINO_API std::shared_ptr any_input(const pattern::op::ValuePredicate& pred); +OPENVINO_API +std::shared_ptr any_input(const pattern::op::SymbolPredicate& pred); } // namespace pattern } // namespace pass } // namespace ov diff --git a/src/core/include/openvino/pass/pattern/op/optional.hpp b/src/core/include/openvino/pass/pattern/op/optional.hpp index 36a64fe1c4d993..80043254f779a7 100644 --- a/src/core/include/openvino/pass/pattern/op/optional.hpp +++ b/src/core/include/openvino/pass/pattern/op/optional.hpp @@ -53,13 +53,9 @@ class OPENVINO_API Optional : public Pattern { /// \param type_infos Optional operation types to exclude them from the matching /// in case the following op types do not exist in a pattern to match. /// \param patterns The pattern to match a graph. - Optional( - const std::vector& type_infos, - const OutputVector& inputs = {}, - const pattern::op::ValuePredicate& pred = - [](const Output& output) { - return true; - }) + Optional(const std::vector& type_infos, + const OutputVector& inputs = {}, + const pattern::op::Predicate& pred = Predicate()) : Pattern(inputs, pred), optional_types(type_infos){}; @@ -88,22 +84,56 @@ void collect_type_info(std::vector& type_info_vec) { } template -std::shared_ptr optional(const OutputVector& inputs, const pattern::op::ValuePredicate& pred = nullptr) { +std::shared_ptr optional(const OutputVector& inputs, const pattern::op::ValuePredicate& pred) { std::vector optional_type_info_vec; collect_type_info(optional_type_info_vec); return std::make_shared(optional_type_info_vec, inputs, pred); } template -std::shared_ptr optional(const Output& input, const pattern::op::ValuePredicate& pred = nullptr) { +std::shared_ptr optional(const Output& input, const pattern::op::ValuePredicate& pred) { return optional(OutputVector{input}, pred); } template -std::shared_ptr optional(const pattern::op::ValuePredicate& pred = nullptr) { +std::shared_ptr optional(const pattern::op::ValuePredicate& pred) { return optional(OutputVector{}, pred); } +template +std::shared_ptr optional(const OutputVector& inputs, const pattern::op::SymbolPredicate& pred) { + std::vector optional_type_info_vec; + collect_type_info(optional_type_info_vec); + return std::make_shared(optional_type_info_vec, inputs, pred); +} + +template +std::shared_ptr optional(const Output& input, const pattern::op::SymbolPredicate& pred) { + return optional(OutputVector{input}, pred); +} + +template +std::shared_ptr optional(const pattern::op::SymbolPredicate& pred) { + return optional(OutputVector{}, pred); +} + +template +std::shared_ptr optional(const OutputVector& inputs) { + std::vector optional_type_info_vec; + collect_type_info(optional_type_info_vec); + return std::make_shared(optional_type_info_vec, inputs, pattern::op::Predicate()); +} + +template +std::shared_ptr optional(const Output& input) { + return optional(OutputVector{input}); +} + +template +std::shared_ptr optional() { + return optional(OutputVector{}); +} + } // namespace pattern } // namespace pass } // namespace ov diff --git a/src/core/include/openvino/pass/pattern/op/pattern.hpp b/src/core/include/openvino/pass/pattern/op/pattern.hpp index 08afddd43a1c09..b0348ddc63653f 100644 --- a/src/core/include/openvino/pass/pattern/op/pattern.hpp +++ b/src/core/include/openvino/pass/pattern/op/pattern.hpp @@ -24,6 +24,48 @@ using PatternValueMaps = std::vector; using PatternMap = std::map, std::shared_ptr>; +class PatternSymbolValue { +public: + PatternSymbolValue() = default; + PatternSymbolValue(const std::shared_ptr& s) : m_value(s){}; + PatternSymbolValue(const int64_t& i) : m_value(i){}; + PatternSymbolValue(const double& d) : m_value(d){}; + + bool is_dynamic() const { + return is_valid() && m_value.is>(); + } + + bool is_static() const { + return !is_dynamic(); + } + + bool is_integer() const { + return is_valid() && m_value.is(); + } + + bool is_double() const { + return is_valid() && m_value.is(); + } + + int64_t i() const { + return m_value.as(); + } + double d() const { + return m_value.as(); + } + std::shared_ptr s() const { + return m_value.as>(); + } + +private: + bool is_valid() const { + return m_value != nullptr && + m_value.is() ^ m_value.is() ^ m_value.is>(); + } + ov::Any m_value; +}; +using PatternSymbolMap = std::unordered_map; + PatternMap as_pattern_map(const PatternValueMap& pattern_value_map); PatternValueMap as_pattern_value_map(const PatternMap& pattern_map); @@ -74,32 +116,68 @@ std::function)> type_matches_any(const std::vector)> all_of(const std::vector)>>& predicates); +OPENVINO_API +std::function&)> shape_matches( + const std::string& shape_notation); + namespace op { using NodePredicate = std::function)>; using ValuePredicate = std::function& value)>; +using SymbolPredicate = std::function&)>; OPENVINO_API ValuePredicate as_value_predicate(NodePredicate pred); +namespace { +constexpr bool symbol_true_predicate(pass::pattern::PatternSymbolMap&, const Output&) { + return true; +} +} // namespace + +class OPENVINO_API Predicate { +public: + Predicate() : m_pred(symbol_true_predicate) {} + Predicate(SymbolPredicate predicate) : m_pred(predicate) {} + Predicate(ValuePredicate predicate) { + m_pred = [=](pass::pattern::PatternSymbolMap&, const Output& output) { + return predicate(output); + }; + } + Predicate(NodePredicate predicate) { + m_pred = [=](pass::pattern::PatternSymbolMap&, const Output& output) { + return predicate(output.get_node_shared_ptr()); + }; + } + Predicate(std::function)> predicate) { + m_pred = [=](pass::pattern::PatternSymbolMap&, const Output& output) { + return predicate(output); + }; + } + + bool operator()(pass::pattern::PatternSymbolMap& m, const Output& output) const { + return m_pred(m, output); + } + +private: + SymbolPredicate m_pred; +}; + class OPENVINO_API Pattern : public Node { public: /// \brief \p a base class for \sa Skip and \sa Label /// - Pattern(const OutputVector& patterns, ValuePredicate pred); - - Pattern(const OutputVector& patterns) : Pattern(patterns, nullptr) {} + Pattern(const OutputVector& patterns, Predicate pred); + Pattern(const OutputVector& patterns); std::shared_ptr clone_with_new_inputs(const OutputVector& /* new_args */) const override { OPENVINO_THROW("Uncopyable"); } - ValuePredicate get_predicate() const; - std::ostream& write_description(std::ostream& out, uint32_t depth) const override; virtual std::ostream& write_type_description(std::ostream& out) const; protected: - ValuePredicate m_predicate; + Predicate m_predicate; }; } // namespace op } // namespace pattern diff --git a/src/core/include/openvino/pass/pattern/op/wrap_type.hpp b/src/core/include/openvino/pass/pattern/op/wrap_type.hpp index 75ee41ffa3753e..a02b3e731f729d 100644 --- a/src/core/include/openvino/pass/pattern/op/wrap_type.hpp +++ b/src/core/include/openvino/pass/pattern/op/wrap_type.hpp @@ -15,25 +15,17 @@ class OPENVINO_API WrapType : public Pattern { public: OPENVINO_RTTI("patternAnyType"); - explicit WrapType( - NodeTypeInfo wrapped_type, - const ValuePredicate& pred = - [](const Output& output) { - return true; - }, - const OutputVector& input_values = {}) + explicit WrapType(NodeTypeInfo wrapped_type, + const Predicate& pred = Predicate(), + const OutputVector& input_values = {}) : Pattern(input_values, pred), m_wrapped_types({wrapped_type}) { set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic()); } - explicit WrapType( - std::vector wrapped_types, - const ValuePredicate& pred = - [](const Output& output) { - return true; - }, - const OutputVector& input_values = {}) + explicit WrapType(std::vector wrapped_types, + const Predicate& pred = Predicate(), + const OutputVector& input_values = {}) : Pattern(input_values, pred), m_wrapped_types(std::move(wrapped_types)) { set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic()); @@ -71,17 +63,27 @@ std::shared_ptr wrap_type(const OutputVector& inputs, const pattern::op::V return std::make_shared(info, pred, inputs); } +template +std::shared_ptr wrap_type(const OutputVector& inputs, const pattern::op::SymbolPredicate& pred) { + std::vector info; + collect_wrap_info(info); + return std::make_shared(info, pred, inputs); +} + template std::shared_ptr wrap_type(const OutputVector& inputs = {}) { - return wrap_type(inputs, [](const Output& output) { - return true; - }); + return wrap_type(inputs, pattern::op::Predicate()); } template std::shared_ptr wrap_type(const pattern::op::ValuePredicate& pred) { return wrap_type({}, pred); } + +template +std::shared_ptr wrap_type(const pattern::op::SymbolPredicate& pred) { + return wrap_type({}, pred); +} } // namespace pattern } // namespace pass } // namespace ov diff --git a/src/core/src/pattern/op/any.cpp b/src/core/src/pattern/op/any.cpp index 64735104ee8c53..9be91444e11cef 100644 --- a/src/core/src/pattern/op/any.cpp +++ b/src/core/src/pattern/op/any.cpp @@ -10,6 +10,6 @@ bool ov::pass::pattern::op::Any::match_value(Matcher* matcher, const Output& pattern_value, const Output& graph_value) { matcher->add_node(graph_value); - return m_predicate(graph_value) && + return m_predicate(matcher->get_symbols(), graph_value) && matcher->match_arguments(pattern_value.get_node(), graph_value.get_node_shared_ptr()); } diff --git a/src/core/src/pattern/op/any_of.cpp b/src/core/src/pattern/op/any_of.cpp index 22688eb1434946..724c45a1846906 100644 --- a/src/core/src/pattern/op/any_of.cpp +++ b/src/core/src/pattern/op/any_of.cpp @@ -10,7 +10,7 @@ bool ov::pass::pattern::op::AnyOf::match_value(Matcher* matcher, const Output& pattern_value, const Output& graph_value) { matcher->add_node(graph_value); - return m_predicate(graph_value) && ([&]() { + return m_predicate(matcher->get_symbols(), graph_value) && ([&]() { for (const auto& arg : graph_value.get_node_shared_ptr()->input_values()) { auto saved = matcher->start_match(); if (matcher->match_value(input_value(0), arg)) { diff --git a/src/core/src/pattern/op/label.cpp b/src/core/src/pattern/op/label.cpp index 1bf916d32ecdb8..75604b46da3a7b 100644 --- a/src/core/src/pattern/op/label.cpp +++ b/src/core/src/pattern/op/label.cpp @@ -22,7 +22,7 @@ ov::Output ov::pass::pattern::op::Label::wrap_values(const ov::OutputV bool ov::pass::pattern::op::Label::match_value(ov::pass::pattern::Matcher* matcher, const ov::Output& pattern_value, const ov::Output& graph_value) { - if (m_predicate(graph_value)) { + if (m_predicate(matcher->get_symbols(), graph_value)) { auto& pattern_map = matcher->get_pattern_value_map(); auto saved = matcher->start_match(); matcher->add_node(graph_value); @@ -43,3 +43,7 @@ std::shared_ptr ov::pass::pattern::any_input() { std::shared_ptr ov::pass::pattern::any_input(const ov::pass::pattern::op::ValuePredicate& pred) { return std::make_shared(element::dynamic, PartialShape::dynamic(), pred); } + +std::shared_ptr ov::pass::pattern::any_input(const ov::pass::pattern::op::SymbolPredicate& pred) { + return std::make_shared(element::dynamic, PartialShape::dynamic(), pred); +} diff --git a/src/core/src/pattern/op/pattern.cpp b/src/core/src/pattern/op/pattern.cpp index 826c6ff532ffb6..3e3865b9848c97 100644 --- a/src/core/src/pattern/op/pattern.cpp +++ b/src/core/src/pattern/op/pattern.cpp @@ -12,9 +12,21 @@ namespace pass { namespace pattern { namespace op { namespace { -constexpr bool node_value_true_predicate(const Output&) { +constexpr bool value_true_predicate(const Output&) { return true; } + +SymbolPredicate as_symbol_predicate(const ValuePredicate& pred) { + return [=](PatternSymbolMap&, const Output& out) { + return pred(out); + }; +} + +SymbolPredicate as_symbol_predicate(const NodePredicate& pred) { + return [=](PatternSymbolMap&, const Output& out) { + return pred(out.get_node_shared_ptr()); + }; +} } // namespace struct NodeValuePredicate { @@ -25,20 +37,15 @@ struct NodeValuePredicate { NodePredicate pred; }; -Pattern::Pattern(const OutputVector& patterns, ValuePredicate pred) - : Node(patterns), - m_predicate(pred ? std::move(pred) : node_value_true_predicate) {} +Pattern::Pattern(const OutputVector& patterns) : Node(patterns), m_predicate() {} -// The symbols are required to be in cpp file to workaround RTTI issue on Android LLVM -ValuePredicate Pattern::get_predicate() const { - return m_predicate; -} +Pattern::Pattern(const OutputVector& patterns, Predicate pred) : Node(patterns), m_predicate(pred) {} ValuePredicate as_value_predicate(NodePredicate pred) { if (pred) { return NodeValuePredicate{std::move(pred)}; } else { - return node_value_true_predicate; + return value_true_predicate; } } @@ -164,6 +171,128 @@ std::function)> all_of(const std::vector>, int64_t> parse_string(const std::string& s_global) { + auto s = s_global; + s.erase(remove_if(s.begin(), s.end(), isspace), s.end()); + const std::set scalar = {"", "[]", "()"}, dynamic = {"?", "..."}; + if (scalar.count(s)) + return {{}, 0}; + if (dynamic.count(s)) + return {{}, -1}; + const std::set brackets_to_remove = {'[', ']', '(', ')'}; + s.erase(remove_if(s.begin(), + s.end(), + [&](char c) { + return brackets_to_remove.count(c); + }), + s.end()); + + std::vector parsed; + size_t pos = 0, pos_next; + std::string token; + while ((pos_next = s.find(',', pos)) != std::string::npos) { + token = s.substr(pos, pos_next - pos); + parsed.push_back(token); + pos = pos_next + 1; + } + // collect whole string if no delimiter is found + token = s.substr(pos, pos_next); + parsed.push_back(token); + + std::vector> idx_to_name; + + bool ellipsis_visited = false; + for (int64_t i = 0; i < parsed.size(); ++i) { + auto dimension = parsed[i]; + if (dimension == "...") { + OPENVINO_ASSERT(!ellipsis_visited, "Only one ellipsis is allowed for symbolic predicate notation"); + ellipsis_visited = true; + continue; + } + idx_to_name.emplace_back((ellipsis_visited ? static_cast(parsed.size()) - i : i), dimension); + } + return {idx_to_name, (ellipsis_visited ? idx_to_name.size() : -2)}; +} + +std::pair str2int(const std::string& str) { + auto s = str.c_str(); + char* end; + int64_t l; + l = strtol(s, &end, 10); + if (*s == '\0' || *end != '\0') + return {1, 0}; + return {0, l}; +} +} // namespace + +pass::pattern::op::SymbolPredicate shape_matches(const std::string& shape_notation) { + /* Shape Notation Rules and Examples: + * Dimension variants: + * - digit -- for static dimension + * - string_name -- for static or symbolic equivalence check / recording; no spaces or commas in the name + * - question mark -- for irrelevant dimensions that don't need recording or checking. Relevant for rank check. + * - ellipsis -- any number of dimensions, including no dimensions. Only one ellipsis is + * Shape may or may not be enclosed with brackets -- [] or (). Dimensions are delimited with commas. + * Spaces are irrelevant. + * + * Examples: + * "[A, 3, C, D]" -- check for rank == 4; A, C, D checked / recorded the Match obj; static dim checked; + * "[A, 3, ..., D]" -- no rank restrictions; A, D checked / recorded the Match obj; static dim checked; + * "[?, D]" -- check for rank == 2; D checked / recorded the Match obj; ? dim -- not checked and not recorded; + * "[Batch, SequenceLength, *]" -- check for rank == 3; Batch, SequenceLength checked / recorded the Match obj; + * */ + const auto& parsed = parse_string(shape_notation); + const auto& idx_to_name = parsed.first; + const auto& rank_restrictions = parsed.second; + + return [=](pass::pattern::PatternSymbolMap& m, const Output& output) -> bool { + const auto& shape = output.get_partial_shape(); + if (rank_restrictions == 0) // scalar + return shape.is_static() && shape.size() == 0; + if (rank_restrictions == -1) // fully dynamic + return shape.rank().is_dynamic(); + if (rank_restrictions == -2 && (shape.rank().is_dynamic() || shape.size() < idx_to_name.size())) + // minimum rank check + return false; + if (rank_restrictions > 0 && (shape.rank().is_dynamic() || shape.size() != rank_restrictions)) + return false; + for (const auto& item : idx_to_name) { + const auto& name = item.second; + if (name == "?") + continue; + const auto& this_dim = shape[item.first]; + auto int_from_str = str2int(name); + if (int_from_str.first) { // failed the conversion -- this is a name + if (m.count(name)) { + const auto& recorded_value = m.at(name); + if (recorded_value.is_dynamic()) { + const auto& recorded_symbol = recorded_value.s(); + if (!ov::symbol::are_equal(recorded_symbol, this_dim.get_symbol())) + return false; + } else if (recorded_value.is_integer()) { + if (this_dim.is_dynamic() || this_dim.get_length() != recorded_value.i()) + return false; + } else { + return false; + } + } else { + if (this_dim.is_static()) + m[name] = {static_cast(this_dim.get_length())}; + else if (auto symbol = this_dim.get_symbol()) + m[name] = {symbol}; + else + return false; + } + } else { // this_dim is not a name, but an integer + if (this_dim.is_dynamic() || this_dim.get_length() != int_from_str.second) + return false; + } + } + return true; + }; +} } // namespace pattern } // namespace pass pass::pattern::op::ValuePredicate operator||(const pass::pattern::op::ValuePredicate& a, diff --git a/src/core/src/pattern/op/wrap_type.cpp b/src/core/src/pattern/op/wrap_type.cpp index aca354f509077c..99ddded328b896 100644 --- a/src/core/src/pattern/op/wrap_type.cpp +++ b/src/core/src/pattern/op/wrap_type.cpp @@ -15,7 +15,7 @@ bool ov::pass::pattern::op::WrapType::match_value(Matcher* matcher, [&](const NodeTypeInfo& type_info) { return graph_value.get_node_shared_ptr()->get_type_info().is_castable(type_info); }) && - m_predicate(graph_value)) { + m_predicate(matcher->get_symbols(), graph_value)) { auto& pattern_map = matcher->get_pattern_value_map(); pattern_map[shared_from_this()] = graph_value; matcher->add_node(graph_value); diff --git a/src/core/tests/pattern.cpp b/src/core/tests/pattern.cpp index 39d30c3fd27492..bc49751e152cbd 100644 --- a/src/core/tests/pattern.cpp +++ b/src/core/tests/pattern.cpp @@ -1290,3 +1290,7 @@ TEST(pattern, pattern_predicate_operator) { ov::pass::pattern::rank_equals(2)), model_add)); } + +TEST(pattern, pattern_symbol_predicate) { + ov::pass::pattern::any_input(ov::pass::pattern::shape_matches("A,B,?")); +} \ No newline at end of file