From 043bc89d2b6eda6a62866fb4cfe79a4fb9dfa2a2 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Tue, 3 Dec 2024 12:25:49 +0100 Subject: [PATCH] [PT FE] Support dynamic shapes torch.export Signed-off-by: Maxim Vafin --- .../openvino/frontend/pytorch/fx_decoder.py | 374 +++++++++++------- .../openvino/frontend/pytorch/ts_decoder.py | 6 +- .../pyopenvino/frontend/pytorch/decoder.hpp | 8 +- .../openvino/frontend/pytorch/decoder.hpp | 7 +- .../frontend/pytorch/node_context.hpp | 39 +- src/frontends/pytorch/src/node_context.cpp | 86 ++-- src/frontends/pytorch/src/op/as_strided.cpp | 77 ++-- src/frontends/pytorch/src/op/cat.cpp | 15 +- src/frontends/pytorch/src/op/expand.cpp | 18 - src/frontends/pytorch/src/op/full.cpp | 11 +- src/frontends/pytorch/src/op/index.cpp | 19 +- src/frontends/pytorch/src/op/repeat.cpp | 25 ++ .../pytorch/src/op/repeat_interleave.cpp | 2 +- src/frontends/pytorch/src/op/reshape.cpp | 24 -- src/frontends/pytorch/src/op_table.cpp | 12 +- .../pytorch/src/translate_session.hpp | 3 +- src/frontends/pytorch/src/utils.cpp | 24 -- src/frontends/pytorch/src/utils.hpp | 8 +- tests/model_hub_tests/pytorch/test_timm.py | 9 +- tests/model_hub_tests/pytorch/timm_models | 24 +- tests/model_hub_tests/pytorch/torch_utils.py | 12 +- 21 files changed, 416 insertions(+), 387 deletions(-) create mode 100644 src/frontends/pytorch/src/op/repeat.cpp diff --git a/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py index c448571f1ac17a..e3e2661bf8f136 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py @@ -5,31 +5,167 @@ # mypy: ignore-errors import logging -import torch import inspect +import torch from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType from openvino.runtime import PartialShape, Type as OVType, OVAny, Shape -from openvino.frontend.pytorch.utils import make_constant, fetch_attr, pt_to_ov_type_map, torch_tensor_to_ov_const +from openvino.frontend.pytorch.utils import ( + make_constant, fetch_attr, pt_to_ov_type_map, torch_tensor_to_ov_const) logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) -class InlinedInput: - def __init__(self, data) -> None: - self.data = data +class BaseFXDecoder (Decoder): + """ + BaseFXDecoder is a class that extends the Decoder class to handle decoding + operations for FX graphs in PyTorch. It provides a common interface for all + FX decoders. 
+ """ + + def __init__(self, mark_node_callback=None) -> None: + Decoder.__init__(self) + self.mark_node_callback = mark_node_callback + # We store every decoder created by this decoder so that + # all them are not deleted until the first decoder is deleted + self.m_decoders = [] + self._inputs = [] + self._outputs = [] + + @staticmethod + def unpack_containers(arg): + if isinstance(arg, (tuple, list)): + res = [] + for e in arg: + res.extend(BaseFXDecoder.unpack_containers(e)) + return res + elif isinstance(arg, dict): + res = [] + for k, e in arg.items(): + unpacked = BaseFXDecoder.unpack_containers(e) + if len(unpacked) == 1: + unpacked[0] = (k, unpacked[0][1]) + res.extend(unpacked) + return res + else: + return [("", arg)] + + @staticmethod + def arg_to_constant(arg): + if isinstance(arg, list): + if len(arg) > 0: + return make_constant(pt_to_ov_type_map[type( + arg[0]).__name__], Shape([len(arg)]), arg) + else: + # TODO: which type should we use if list is empty? Need a signaling value here + return make_constant(OVType.i32, Shape([0]), []) + elif isinstance(arg, bool): + return make_constant(OVType.boolean, Shape([]), [arg]) + elif isinstance(arg, int): + return make_constant(OVType.i64, Shape([]), [arg]) + elif isinstance(arg, float): + return make_constant(OVType.f32, Shape([]), [arg]) + elif isinstance(arg, str): + u8_tensor = torch.frombuffer(str.encode(arg), dtype=torch.uint8) + return torch_tensor_to_ov_const(u8_tensor, shared_memory=True) + return None + + @staticmethod + def get_type_for_value(value): + if issubclass(type(value), torch.fx.Node): + if ('tensor_meta' in value.meta.keys()): + if value.meta['tensor_meta'] and isinstance(value.meta['tensor_meta'], torch.Tensor): + pt_type = value.meta['tensor_meta'].dtype + if str(pt_type) in pt_to_ov_type_map: + ov_type = pt_to_ov_type_map[str(pt_type)] + return OVAny(ov_type) + return OVAny(OVType.dynamic) + elif isinstance(value, int): + return OVAny(DecoderType.PyScalar(OVAny(OVType.i64))) + elif isinstance(value, float): + return OVAny(DecoderType.PyScalar(OVAny(OVType.f32))) + elif isinstance(value, bool): + return OVAny(DecoderType.PyScalar(OVAny(OVType.boolean))) + elif isinstance(value, list): + return OVAny(DecoderType.List(BaseFXDecoder.get_type_for_value(value[0]))) + return OVAny(OVType.dynamic) + def inputs(self): + # Consider 0 a special case which may mean the input is inlined, but not guaranteed + return [x if not isinstance(x, InlinedInput) else 0 for x in self._inputs] + + def input(self, index): + return self.inputs()[index] + + def output(self, index): + return self.outputs()[index] + + def get_input_debug_name(self, index): + return "input" + str(index) + + def is_input_inlined(self, index): + return isinstance(self._inputs[index], InlinedInput) + + def get_inlined_input_decoder(self, index): + target = self._inputs[index] + assert isinstance(target, InlinedInput), "Requested non-inlined input" + in_decoder = InlinedInputDecoder( + target, self._nodes, self.mark_node_callback) + self.m_decoders.append(in_decoder) + return in_decoder + + def get_input_shape(self, index): + return PartialShape.dynamic() -class TorchFXPythonDecoder (Decoder): + def get_input_type(self, index): + return OVAny(OVType.dynamic) + + def get_output_type(self, index): + return OVAny(OVType.dynamic) + + def input_is_none(self, index): + if index < len(self._inputs) and isinstance(self._inputs[index], InlinedInput): + return self._inputs[index].data is None + return False + + def decoder_type_name(self) -> str: + return "fx" + + def 
get_schema(self): + return 'NONE' + + def mark_node(self, node): + if self.mark_node_callback is not None: + self.mark_node_callback(self, node) + return node + + def get_subgraphs(self): + return [] + + def get_subgraph_size(self): + return len(self.get_subgraphs()) + + def as_string(self): + return None + + def may_produce_alias(self, in_index: int, out_index: int) -> bool: + return False + + def get_rt_info(self): + rt_info = {} + return rt_info + + +class TorchFXPythonDecoder (BaseFXDecoder): + """ + Decoder for PyTorch FX GraphModule and Node objects to OpenVINO IR. + """ def __init__(self, pt_module, fx_gm=None, nodes=None, mark_node_callback=None, input_shapes=[], input_types=[], dynamic_shapes=False): - Decoder.__init__(self) - self.mark_node_callback = mark_node_callback - # We store every decoder created by this decoder so that all them are not deleted until the first decoder is deleted - self.m_decoders = [] + super().__init__(mark_node_callback) self.pt_module = pt_module self.fx_gm = fx_gm if fx_gm is not None else pt_module self.input_types = [OVAny(pt_to_ov_type_map[str(t)]) @@ -40,17 +176,13 @@ def __init__(self, pt_module, fx_gm=None, nodes=None, self._example_input = None if issubclass(type(pt_module), torch.fx.graph_module.GraphModule): - self._input_is_list = None self._nodes = list(pt_module.graph.nodes) - self._inputs = [] - self._outputs = [] found_types = [] found_shapes = [] - for i in range(len(self._nodes)): - if self._nodes[i].op == 'placeholder': + for i, value in enumerate(self._nodes): + if value.op == 'placeholder': self._inputs.append(i) - value = self._nodes[i] self._input_signature.append(value.name) if hasattr(value, "meta") and ('tensor_meta' in value.meta.keys()) and value.meta['tensor_meta']: found_shapes.append(value.meta['tensor_meta'].shape) @@ -59,19 +191,19 @@ def __init__(self, pt_module, fx_gm=None, nodes=None, else: found_shapes.append(None) found_types.append(None) - elif self._nodes[i].op == 'output': + elif value.op == 'output': # Instead of putting output index, refer to its target - uargs = self.unpack_containers(self._nodes[i].args) + uargs = self.unpack_containers(value.args) self._outputs = [(arg[0], self._nodes.index(arg[1])) for arg in uargs if arg[1] is not None] for idx, shape in enumerate(found_shapes): if shape is not None: new_shape = [] - for dim in range(0, len(shape)): - if (dynamic_shapes or type(shape[dim]).__name__ == "SymInt"): + for dim in shape: + if (dynamic_shapes or type(dim).__name__ == "SymInt"): new_shape.append(-1) else: - new_shape.append(shape[dim]) + new_shape.append(dim) found_shapes[idx] = torch.Size(new_shape) if not input_shapes or len(input_shapes) == 0: @@ -84,99 +216,20 @@ def __init__(self, pt_module, fx_gm=None, nodes=None, self._input_signature = list(input_params) elif issubclass(type(pt_module), torch.fx.Node): - self._nodes = nodes # passed from outer context # FIXME: Quadratic complexity nodes*nodes considering the outer loop over all nodes self._outputs = [("", self._nodes.index(pt_module))] - # None in inputs mean the input is inlined or None (also considered inlined) - self._inputs = [self._nodes.index( - arg) if arg in self._nodes else InlinedInput(arg) for arg in pt_module.args] - - # FIXME: Find a better way to pass nested tuples to OV frontend. This is a temporary solution to flatten arguments. 
- new_inputs = [] self.input_types = [] - for i in range(len(pt_module.args)): - if isinstance(pt_module.args[i], (list, tuple)) and any([isinstance(a, torch.fx.Node) for a in pt_module.args[i]]): - for arg in pt_module.args[i]: - if arg in self._nodes: - new_inputs.append(self._nodes.index(arg)) - else: - new_inputs.append(InlinedInput(arg)) - self.input_types.append(OVAny(DecoderType.List( - TorchFXPythonDecoder.get_type_for_value(arg)))) + for arg in pt_module.args: + if isinstance(arg, torch.fx.Node): + self._inputs.append(self._nodes.index(arg)) else: - v = self._inputs[i] - new_inputs.append(v) - self.input_types.append( - TorchFXPythonDecoder.get_type_for_value(v.data if isinstance(v, InlinedInput) else self._nodes[v])) - self._inputs = new_inputs - - def inputs(self): - # Consider 0 a special case which may mean the input is inlined, but not guaranteed - return [x if not isinstance(x, InlinedInput) else 0 for x in self._inputs] - - def is_input_inlined(self, index): - return isinstance(self._inputs[index], InlinedInput) - - @staticmethod - def unpack_containers(arg): - if isinstance(arg, (tuple, list)): - res = [] - for e in arg: - res.extend(TorchFXPythonDecoder.unpack_containers(e)) - return res - elif isinstance(arg, dict): - res = [] - for k, e in arg.items(): - unpacked = TorchFXPythonDecoder.unpack_containers(e) - if len(unpacked) == 1: - unpacked[0] = (k, unpacked[0][1]) - res.extend(unpacked) - return res - else: - return [("", arg)] - - @staticmethod - def arg_to_constant(arg): - if isinstance(arg, list): - if len(arg) > 0: - return make_constant(pt_to_ov_type_map[type( - arg[0]).__name__], Shape([len(arg)]), arg) - else: - # TODO: which type should we use if list is empty? Need a signaling value here - return make_constant(OVType.i32, Shape([0]), []) - elif isinstance(arg, bool): - return make_constant(OVType.boolean, Shape([]), [arg]) - elif isinstance(arg, int): - return make_constant(OVType.i64, Shape([]), [arg]) - elif isinstance(arg, float): - return make_constant(OVType.f32, Shape([]), [arg]) - elif isinstance(arg, str): - u8_tensor = torch.frombuffer(str.encode(arg), dtype=torch.uint8) - return torch_tensor_to_ov_const(u8_tensor, shared_memory=True) - return None - - def inlined_input(self, index): - assert index < len(self._inputs), "Requested input doesn't exist" - assert isinstance( - self._inputs[index], InlinedInput), "Requested input which is not inlined" - arg = self._inputs[index].data - assert arg is not None, f"Requested None inlined input for op {self.get_op_type()}" - constant = None - constant = self.arg_to_constant(arg) - - if constant is not None: - return constant.outputs() - else: - return [] - - def input(self, index): # TODO: remove - return self.inputs()[index] # TODO: find specialized method - - def get_input_debug_name(self, index): - return "input"+str(index) + # Not a node, consider it inlined + self._inputs.append(InlinedInput(arg)) + self.input_types.append( + BaseFXDecoder.get_type_for_value(arg)) def get_input_signature_name(self, index: int) -> str: if self._input_signature is not None and index < len(self._input_signature): @@ -225,24 +278,6 @@ def get_shape_for_value(self, value): return PartialShape(len(value.meta['tensor_meta'].shape) * [-1]) return PartialShape.dynamic() - @staticmethod - def get_type_for_value(value): - if issubclass(type(value), torch.fx.Node): - if ('tensor_meta' in value.meta.keys()): - if value.meta['tensor_meta'] and isinstance(value.meta['tensor_meta'], torch.Tensor): - pt_type = 
value.meta['tensor_meta'].dtype - if str(pt_type) in pt_to_ov_type_map: - ov_type = pt_to_ov_type_map[str(pt_type)] - return OVAny(ov_type) - return OVAny(OVType.dynamic) - elif isinstance(value, int): - return OVAny(DecoderType.PyScalar(OVAny(OVType.i64))) - elif isinstance(value, float): - return OVAny(DecoderType.PyScalar(OVAny(OVType.f32))) - elif isinstance(value, bool): - return OVAny(DecoderType.PyScalar(OVAny(OVType.boolean))) - return OVAny(OVType.dynamic) - def get_attribute(self, name): if name in self.pt_module.kwargs: attr = self.pt_module.kwargs[name] @@ -272,12 +307,6 @@ def get_named_input(self, name): return self._nodes.index(arg) raise RuntimeError("This input is not a Node") - def get_subgraph_size(self): - return len(self.get_subgraphs()) - - def decoder_type_name(self) -> str: - return "fx" - def visit_subgraph(self, node_visitor): # make sure topological order is satisfied for node in self._nodes: @@ -290,9 +319,6 @@ def visit_subgraph(self, node_visitor): self.m_decoders.append(decoder) node_visitor(decoder) - def get_subgraphs(self): - return [] - def get_subgraph_decoder(self, index): decoder = TorchFXPythonDecoder(self.get_subgraphs()[index], self.fx_gm, @@ -308,9 +334,6 @@ def get_op_type(self): else: return 'UNKNOWN_TYPE_' + str(self.pt_module.op) - def get_schema(self): - return 'NONE' - def outputs(self): return [o[1] for o in self._outputs] @@ -336,16 +359,12 @@ def output_list_size(self): max_out_id = user.args[1] return max_out_id + 1 - def output(self, index): - return self.outputs()[index] - def mark_node(self, node): name = self.get_op_type() if "FrameworkNode" not in node.get_type_name(): name += "/" + node.get_type_name() node.set_friendly_name(self.pt_module.name + "/" + name) - if self.mark_node_callback is not None: - self.mark_node_callback(self, node) + super().mark_node(node) return node def as_constant(self): @@ -355,9 +374,6 @@ def as_constant(self): ov_const = torch_tensor_to_ov_const(ret, shared_memory=True) return ov_const.outputs() - def as_string(self): - return None - def input_is_none(self, index): if index >= len(self._inputs) or (isinstance(self._inputs[index], InlinedInput) and self._inputs[index].data is None): return True @@ -368,9 +384,69 @@ def input_is_none(self, index): def debug(self): self.pt_module.print() - def may_produce_alias(self, in_index: int, out_index: int) -> bool: + +class InlinedInput: + """ + Represents an inlined input. This is a special case + where the input is not a node, but a constant value. + """ + + def __init__(self, data) -> None: + self.data = data + + +class InlinedInputDecoder (BaseFXDecoder): + """ + Decoder for inlined inputs in PyTorch FX graphs. 
+ """ + + def __init__(self, inlined_input: InlinedInput, nodes=None, mark_node_callback=None) -> None: + super().__init__(mark_node_callback) + self.inlined_input = inlined_input + self._nodes = nodes + self.is_const = not (isinstance(inlined_input.data, (list, tuple)) and any( + isinstance(a, torch.fx.Node) for a in inlined_input.data)) + if not self.is_const: + self._inputs = [nodes.index(x) if isinstance( + x, torch.fx.Node) else InlinedInput(x) for x in inlined_input.data] + + def get_op_type(self): + # return specific type for inlined inputs + if not self.is_const: + return "prim::ListConstruct" + return "inlined.constant.default" + + def outputs(self): + return [0] + + def num_of_outputs(self): + return 1 + + def get_input_shape(self, index): + return PartialShape.dynamic() + + def get_input_type(self, index): + return OVAny(OVType.dynamic) + + def get_output_type(self, index): + return OVAny(OVType.dynamic) + + def input_is_none(self, index): + if index < len(self._inputs) and isinstance(self._inputs[index], InlinedInput): + return self._inputs[index].data is None return False - def get_rt_info(self): - rt_info = {} - return rt_info + def as_constant(self): + arg = self.inlined_input.data + constant = BaseFXDecoder.arg_to_constant(arg) + if constant is not None: + return constant.outputs() + return [] + + def mark_node(self, node): + name = self.get_op_type() + if "FrameworkNode" not in node.get_type_name(): + name += "/" + node.get_type_name() + node.set_friendly_name(name) + super().mark_node(node) + return node diff --git a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py index 6d8fdb1658793e..9224c2b1fe76bc 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py @@ -515,12 +515,12 @@ def may_produce_alias(self, in_index: int, out_index: int) -> bool: # Sometimes pytorch fails to get result with IndexError exception while these indexes exist in node return False - def inlined_input(self, index): - return [] - def is_input_inlined(self, index): return False + def get_inlined_input_decoder(self, index): + return None + def get_attribute(self, name): return OVAny(None) diff --git a/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp b/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp index 3b044011e56de5..2a1d9dae127375 100644 --- a/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp +++ b/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp @@ -114,14 +114,14 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder { PYBIND11_OVERRIDE_PURE(bool, TorchDecoder, may_produce_alias, in_index, out_index); } - ov::OutputVector inlined_input(size_t index) const override { - PYBIND11_OVERRIDE_PURE(ov::OutputVector, TorchDecoder, inlined_input, index); - } - bool is_input_inlined(size_t index) const override { PYBIND11_OVERRIDE_PURE(bool, TorchDecoder, is_input_inlined, index); } + std::shared_ptr get_inlined_input_decoder(size_t index) const override { + PYBIND11_OVERRIDE_PURE(std::shared_ptr, TorchDecoder, get_inlined_input_decoder, index); + } + ov::Any get_attribute(const std::string &name) const override{ PYBIND11_OVERRIDE_PURE(ov::Any, TorchDecoder, get_attribute, name); } diff --git a/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp b/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp index 9fd537c7813d00..a2c92d1aa5cab9 
100644 --- a/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp +++ b/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp @@ -113,14 +113,13 @@ class TorchDecoder : public IDecoder { /// \brief Returns if output may contain alias of input in AliasDB virtual bool may_produce_alias(size_t in_index, size_t out_index) const = 0; - /// Returns new nodes for inputs inlined in the op itself - // Used in Torch.FX decoder - virtual OutputVector inlined_input(size_t index) const = 0; - /// Returns if input is inlined // Used in Torch.FX decoder virtual bool is_input_inlined(size_t index) const = 0; + /// Return decoder for inlined input + virtual std::shared_ptr get_inlined_input_decoder(size_t index) const = 0; + /// Returns named attribute as Any. For example kwargs input for FX graph virtual ov::Any get_attribute(const std::string& name) const = 0; diff --git a/src/frontends/pytorch/include/openvino/frontend/pytorch/node_context.hpp b/src/frontends/pytorch/include/openvino/frontend/pytorch/node_context.hpp index d97c2ef694a1eb..ad4d1eb3f93412 100644 --- a/src/frontends/pytorch/include/openvino/frontend/pytorch/node_context.hpp +++ b/src/frontends/pytorch/include/openvino/frontend/pytorch/node_context.hpp @@ -51,44 +51,9 @@ class NodeContext : public frontend::NodeContext { // Search for input in tensor map and return an output port for already converted op // TODO: int due to base class uses it, but naturally it should be size_t for PT - Output get_input(int index) const override { - size_t index_ = static_cast(index); - FRONT_END_GENERAL_CHECK(!m_decoder->input_is_none(index_), - "Input doesn't exist with index: ", - index, - " for operation ", - get_op_type()); - auto input = m_decoder_inputs.at(index); - if (input == 0) { - // Case when input can be inlined (possible only for fx decoder) - if (m_decoder->is_input_inlined(index_)) { - auto inlined_input = m_decoder->inlined_input(index_); - FRONT_END_GENERAL_CHECK(inlined_input.size() == 1, - "Incorrect inlined input with index: ", - index, - " for operation ", - get_op_type()); - return inlined_input[0]; - } - } - FRONT_END_GENERAL_CHECK(m_tensor_map->count(input), "No tensor corresponding input: ", input, " exist."); - return m_tensor_map->at(input); - } + Output get_input(int index) const override; - Output get_input(const std::string& name) const override { - FRONT_END_GENERAL_CHECK(has_attribute(name), "Input with name ", name, " doesn't exist"); - auto attr = get_attribute_as_any(name); - if (attr.is>()) { - // Case when input is constant value - return attr.as>(); - } else if (attr.is()) { - // None means input is unknown type, most likely a Node - auto input = m_decoder->get_named_input(name); - FRONT_END_GENERAL_CHECK(m_tensor_map->count(input), "No tensor corresponding input: ", input, " exist."); - return m_tensor_map->at(input); - } - FRONT_END_GENERAL_CHECK(false, "Input has type which can't be converted to ov::Node."); - } + Output get_input(const std::string& name) const override; Any get_values_from_const_input(int index) const override; diff --git a/src/frontends/pytorch/src/node_context.cpp b/src/frontends/pytorch/src/node_context.cpp index 6a8c370ef2b410..46fd695339f074 100644 --- a/src/frontends/pytorch/src/node_context.cpp +++ b/src/frontends/pytorch/src/node_context.cpp @@ -145,39 +145,65 @@ std::shared_ptr NodeContext::convert_subgraph(size_t index) const { return model; } +Output NodeContext::get_input(int index) const { + size_t index_ = static_cast(index); + auto input = 
m_decoder_inputs.at(index); + if (input == 0) { + // Case when input can be inlined (possible only for fx decoder) + if (m_decoder->is_input_inlined(index_)) { + if (m_decoder->input_is_none(index_)) { + // some operations like aten.index.Tensor can have None inputs + auto dummy_decoder = std::make_shared("torch::None", 1); + auto fw_node = std::make_shared(dummy_decoder, OutputVector{}); + auto attrs = fw_node->get_attrs(); + attrs["none_value"] = ""; + attrs[PtFrameworkNode::failed_conversion_key] = + "None constant cannot be converted to OpenVINO opset and should be removed by consuming " + "operation."; + fw_node->set_attrs(attrs); + return fw_node->output(0); + } else { + auto inlined_decoder = m_decoder->get_inlined_input_decoder(index_); + auto inlined_ctx = NodeContext(inlined_decoder, + m_ext_tensor_map, + m_tensor_map, + m_external_parameters, + m_mutated_tensors, + m_translate_session); + auto inlined_input = m_translate_session->convert_node(inlined_ctx); + FRONT_END_GENERAL_CHECK(inlined_input.size() == 1, + "Incorrect inlined input with index: ", + index, + " for operation ", + get_op_type()); + return inlined_input[0]; + } + } + } + auto tensor_it = m_tensor_map->find(input); + FRONT_END_GENERAL_CHECK(tensor_it != m_tensor_map->end(), "No tensor corresponding input: ", input, " exist."); + return tensor_it->second; +} + +Output NodeContext::get_input(const std::string& name) const { + FRONT_END_GENERAL_CHECK(has_attribute(name), "Input with name ", name, " doesn't exist"); + auto attr = get_attribute_as_any(name); + if (attr.is>()) { + // Case when input is constant value + return attr.as>(); + } else if (attr.is()) { + // None means input is unknown type, most likely a Node + auto input = m_decoder->get_named_input(name); + FRONT_END_GENERAL_CHECK(m_tensor_map->count(input), "No tensor corresponding input: ", input, " exist."); + return m_tensor_map->at(input); + } + FRONT_END_GENERAL_CHECK(false, "Input has type which can't be converted to ov::Node."); +} + OutputVector NodeContext::inputs() const { OutputVector res; for (size_t i = 0; i < m_decoder_inputs.size(); i++) { - auto input = m_decoder_inputs.at(i); - if (input == 0) { - // Case when input can be inlined (possible only for fx decoder) - if (m_decoder->is_input_inlined(i)) { - if (input_is_none(i)) { - // some operations like aten.index.Tensor can have None inputs - auto dummy_decoder = std::make_shared("torch::None", 1); - auto fw_node = std::make_shared(dummy_decoder, OutputVector{}); - auto attrs = fw_node->get_attrs(); - attrs["none_value"] = ""; - attrs[PtFrameworkNode::failed_conversion_key] = - "None constant cannot be converted to OpenVINO opset and should be removed by consuming " - "operation."; - fw_node->set_attrs(attrs); - res.push_back(fw_node->output(0)); - } else { - auto inlined_input = m_decoder->inlined_input(i); - FRONT_END_GENERAL_CHECK(inlined_input.size() == 1, - "Incorrect inlined input with index: ", - i, - " for operation ", - get_op_type()); - res.push_back(inlined_input[0]); - } - continue; - } - } - auto tensor_it = m_tensor_map->find(input); - FRONT_END_GENERAL_CHECK(tensor_it != m_tensor_map->end(), "No tensor corresponding input: ", input, " exist."); - res.push_back(tensor_it->second); + res.push_back(get_input(i)); } return res; } diff --git a/src/frontends/pytorch/src/op/as_strided.cpp b/src/frontends/pytorch/src/op/as_strided.cpp index 20efb3ba5cb684..c3ae902a81defa 100644 --- a/src/frontends/pytorch/src/op/as_strided.cpp +++ b/src/frontends/pytorch/src/op/as_strided.cpp @@ 
-20,21 +20,20 @@ namespace pytorch { namespace op { using namespace ov::op; + +namespace { bool compare_strides(const std::tuple& a, const std::tuple& b) { return std::get<0>(a) > std::get<0>(b); } -OutputVector translate_as_strided(const NodeContext& context) { - // "aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)" - num_inputs_check(context, 3, 4); - auto decoder = context.get_decoder(); - auto input = context.get_input(0); + +OutputVector translate_as_strided_common(const NodeContext& context, + const Output& input, + const std::vector& input_strides, + const std::deque>& sizes, + const std::deque>& strides) { auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1})); auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); - auto input_strides = decoder->get_input_strides(0); - PYTORCH_OP_CONVERSION_CHECK(input_strides.size() != 0, - "aten::as_strided: Couldn't retrieve input stride information from torchscript."); - std::vector idxs(input_strides.size()); iota(idxs.begin(), idxs.end(), 0); std::vector> stride_idxs(idxs.size()); @@ -53,26 +52,6 @@ OutputVector translate_as_strided(const NodeContext& context) { context.mark_node(v0::Constant::create(element::i32, Shape{transpose_idx.size()}, transpose_idx)); auto transposed_input = context.mark_node(std::make_shared(input, transpose_idx_const)); auto flat_input = context.mark_node(std::make_shared(transposed_input, const_neg_1, false)); - std::deque> sizes; - std::deque> strides; - if (std::dynamic_pointer_cast(context.get_input_from_visible_context(1).get_node_shared_ptr())) { - auto input_vector = context.const_input>(1); - std::for_each(input_vector.rbegin(), input_vector.rend(), [&](int64_t input_val) { - auto const_input = context.mark_node(v0::Constant::create(element::i32, Shape{}, {input_val})); - sizes.push_front(const_input); - }); - } else { - sizes = get_list_as_outputs(context.get_input(1)); - } - if (std::dynamic_pointer_cast(context.get_input_from_visible_context(2).get_node_shared_ptr())) { - auto input_vector = context.const_input>(2); - std::for_each(input_vector.rbegin(), input_vector.rend(), [&](int64_t input_val) { - auto const_input = context.mark_node(v0::Constant::create(element::i32, Shape{}, {input_val})); - strides.push_front(const_input); - }); - } else { - strides = get_list_as_outputs(context.get_input(2)); - } auto offset = const_0->output(0); if (!context.input_is_none(3)) { offset = get_input_as_i32(context, 3); @@ -84,12 +63,12 @@ OutputVector translate_as_strided(const NodeContext& context) { auto strides_length_const = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {strides.size()})); auto ones_strides_len = context.mark_node(std::make_shared(const_1, strides_length_const)); auto indices = const_0; - std::for_each(strides.rbegin(), strides.rend(), [&](Output& stride) { + std::for_each(strides.rbegin(), strides.rend(), [&](const Output& stride) { auto const_num_iter = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {strides_size - i})); - stride = context.mark_node(std::make_shared(stride, element::i32)); + auto stride_conv = context.mark_node(std::make_shared(stride, element::i32)); auto size = sizes.at(strides_size - i); auto range = context.mark_node(std::make_shared(const_0, size, const_1, element::i32)); - range = context.mark_node(std::make_shared(range, stride)); + 
range = context.mark_node(std::make_shared(range, stride_conv)); auto iteration_shape = context.mark_node( std::make_shared(ones_strides_len, const_num_iter, const_neg_1, const_0)); range = context.mark_node(std::make_shared(range, iteration_shape, false)); @@ -99,7 +78,41 @@ OutputVector translate_as_strided(const NodeContext& context) { indices = context.mark_node(std::make_shared(indices, offset)); auto gather = context.mark_node(std::make_shared(flat_input, indices, const_0)); return {gather}; +} +} // namespace + +OutputVector translate_as_strided(const NodeContext& context) { + // "aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)" + num_inputs_check(context, 3, 4); + auto decoder = context.get_decoder(); + auto input = context.get_input(0); + auto input_strides = decoder->get_input_strides(0); + PYTORCH_OP_CONVERSION_CHECK(input_strides.size() != 0, + "aten::as_strided: Couldn't retrieve input stride information from torchscript."); + + std::deque> sizes; + std::deque> strides; + if (std::dynamic_pointer_cast(context.get_input_from_visible_context(1).get_node_shared_ptr())) { + auto input_vector = context.const_input>(1); + std::for_each(input_vector.rbegin(), input_vector.rend(), [&](int64_t input_val) { + auto const_input = context.mark_node(v0::Constant::create(element::i32, Shape{}, {input_val})); + sizes.push_front(const_input); + }); + } else { + sizes = get_list_as_outputs(context.get_input(1)); + } + if (std::dynamic_pointer_cast(context.get_input_from_visible_context(2).get_node_shared_ptr())) { + auto input_vector = context.const_input>(2); + std::for_each(input_vector.rbegin(), input_vector.rend(), [&](int64_t input_val) { + auto const_input = context.mark_node(v0::Constant::create(element::i32, Shape{}, {input_val})); + strides.push_front(const_input); + }); + } else { + strides = get_list_as_outputs(context.get_input(2)); + } + return translate_as_strided_common(context, input, input_strides, sizes, strides); }; + } // namespace op } // namespace pytorch } // namespace frontend diff --git a/src/frontends/pytorch/src/op/cat.cpp b/src/frontends/pytorch/src/op/cat.cpp index d4f12cae258ad8..b84054e21f67ed 100644 --- a/src/frontends/pytorch/src/op/cat.cpp +++ b/src/frontends/pytorch/src/op/cat.cpp @@ -104,18 +104,11 @@ OutputVector translate_cat(const NodeContext& context) { }; OutputVector translate_cat_fx(const NodeContext& context) { - // This translator is only needed to get axis as constant from external scope - num_inputs_check(context, 1, context.get_input_size()); - std::deque> list_elems; - for (size_t i = 0; i < context.get_input_size() - 1; i++) { - list_elems.push_back(context.get_input(static_cast(i))); - } + num_inputs_check(context, 1, 2); + const auto&& list_elems = get_list_as_outputs(context.get_input(0)); int64_t axis = 0; - if (!context.get_input_type(context.get_input_size() - 1).is()) { - // axis can be not present and that means that last input will have List type - axis = context.const_input(context.get_input_size() - 1); - } else { - list_elems.push_back(context.get_input(static_cast(context.get_input_size() - 1))); + if (!context.input_is_none(1)) { + axis = context.const_input(1); } return translate_cat_common(context, list_elems, axis, true); }; diff --git a/src/frontends/pytorch/src/op/expand.cpp b/src/frontends/pytorch/src/op/expand.cpp index a6bc239df96562..6c07c3b53318c8 100644 --- a/src/frontends/pytorch/src/op/expand.cpp +++ b/src/frontends/pytorch/src/op/expand.cpp @@ -43,24 +43,6 @@ 
OutputVector translate_expand_as(const NodeContext& context) { return {context.mark_node(std::make_shared(x, shape, BroadcastType::BIDIRECTIONAL))}; }; -OutputVector translate_expand_fx(const NodeContext& context) { - auto num_inputs = context.get_input_size(); - num_inputs_check(context, 2, num_inputs); - auto x = context.get_input(0); - std::vector shape_vec; - if (context.get_input_type(1).is()) { - auto concat = concat_list_from_inputs(context, 1, num_inputs); - return base_expand(context, x, concat); - } else { - auto x = context.get_input(0); - auto sizes = context.get_input(1); - // TODO: figure out what implicit means - PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(2) || context.const_input(2) == false, - "Unexpected value of implicit for expand operation"); - return base_expand(context, x, sizes); - } -}; - } // namespace op } // namespace pytorch } // namespace frontend diff --git a/src/frontends/pytorch/src/op/full.cpp b/src/frontends/pytorch/src/op/full.cpp index 799c5d6feaebbe..f0693dd07b289c 100644 --- a/src/frontends/pytorch/src/op/full.cpp +++ b/src/frontends/pytorch/src/op/full.cpp @@ -76,14 +76,9 @@ OutputVector translate_full_fx(const NodeContext& context) { // aten.full.default([16, 16], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), // pin_memory = False) auto num_inputs = context.get_input_size(); - num_inputs_check(context, 2, num_inputs); - ov::Output sizes; - if (context.get_input_type(0).is()) { - sizes = concat_list_from_inputs(context, 0, num_inputs - 1); - } else { - sizes = context.get_input(0); - } - auto value = context.get_input(static_cast(num_inputs - 1)); + num_inputs_check(context, 2, 2); + auto sizes = get_input_concat_if_list(context, 0); + auto value = context.get_input(1); auto filled_tensor = base_translate_full(context, sizes, value); if (context.has_attribute("dtype")) { diff --git a/src/frontends/pytorch/src/op/index.cpp b/src/frontends/pytorch/src/op/index.cpp index 880e0acee0f983..b6c1ccaba963c3 100644 --- a/src/frontends/pytorch/src/op/index.cpp +++ b/src/frontends/pytorch/src/op/index.cpp @@ -68,16 +68,9 @@ OutputVector translate_index(const NodeContext& context) { }; OutputVector translate_index_fx(const NodeContext& context) { - num_inputs_check(context, 2, context.get_input_size()); + num_inputs_check(context, 2, 2); auto x = context.get_input(0); - std::deque> list_elems; - for (size_t i = 1; i < context.get_input_size(); i++) { - Output index; - if (!context.input_is_none(i)) { - index = context.get_input(static_cast(i)); - } - list_elems.push_back(index); - } + auto list_elems = get_list_as_outputs(context.get_input(1)); ov::pass::NodeRegistry rg; auto rank = x.get_partial_shape().rank(); if (rank.is_dynamic()) { @@ -86,7 +79,13 @@ OutputVector translate_index_fx(const NodeContext& context) { // index transformation supports only tensors with static rank PYTORCH_OP_CONVERSION_CHECK(rank.is_static(), "Dynamic rank for aten::index input is not supported."); - OutputVector ids{list_elems.begin(), list_elems.end()}; + OutputVector ids; + for (size_t i = 0; i < list_elems.size(); ++i) { + if (!is_none_node(list_elems[i])) + ids.push_back(list_elems[i]); + else + ids.push_back(Output()); + } ov::Output res; bool use_input_as_output = true; index_tensor_on_list(rg, x, ids, rank, res, use_input_as_output); diff --git a/src/frontends/pytorch/src/op/repeat.cpp b/src/frontends/pytorch/src/op/repeat.cpp new file mode 100644 index 00000000000000..67dbe6fda31825 --- /dev/null +++ 
b/src/frontends/pytorch/src/op/repeat.cpp @@ -0,0 +1,25 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/tile.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +using namespace ov::op; + +OutputVector translate_repeat_fx(const NodeContext& context) { + num_inputs_check(context, 2, 2); + auto repeats = get_input_concat_if_list(context, 1); + return {context.mark_node(std::make_shared(context.get_input(0), repeats))}; +} + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/op/repeat_interleave.cpp b/src/frontends/pytorch/src/op/repeat_interleave.cpp index 2b5214ae765af1..1a02e82c501307 100644 --- a/src/frontends/pytorch/src/op/repeat_interleave.cpp +++ b/src/frontends/pytorch/src/op/repeat_interleave.cpp @@ -68,7 +68,7 @@ OutputVector translate_repeat_interleave(const NodeContext& context) { } } else { // repeats is not Constant or single element constant - // Curently we support only case when repeats contains only one element. Otherwise next Reshape will fail. + // Currently we support only case when repeats contains only one element. Otherwise next Reshape will fail. auto repeats_input = context.mark_node(std::make_shared(context.get_input(1), element::i32)); repeats_input = context.mark_node(std::make_shared(repeats_input, const_1_list, false)); auto repeats = context.mark_node(std::make_shared(OutputVector{repeats_input, const_1_list}, 0)); diff --git a/src/frontends/pytorch/src/op/reshape.cpp b/src/frontends/pytorch/src/op/reshape.cpp index a2c1a43a4fcb53..728998d31f3244 100644 --- a/src/frontends/pytorch/src/op/reshape.cpp +++ b/src/frontends/pytorch/src/op/reshape.cpp @@ -26,30 +26,6 @@ OutputVector translate_reshape(const NodeContext& context) { return {context.mark_node(reshape)}; }; -OutputVector translate_reshape_fx(const NodeContext& context) { - // Schema: aten.view.default(Tensor input, int[] shape) -> Tensor - auto num_inputs = context.get_input_size(); - num_inputs_check(context, 2, num_inputs); - std::vector shape_vec; - if (context.get_input_type(1).is()) { - auto concat = concat_list_from_inputs(context, 1, num_inputs); - auto reshape = std::make_shared(context.get_input(0), concat, true); - return {context.mark_node(reshape)}; - } else { - auto shape_input = context.get_input(1); - if (shape_input.get_partial_shape().rank().is_dynamic() || - shape_input.get_partial_shape().rank().get_length() == 0) { - shape_vec.push_back(0); - auto shape_const = ov::op::v0::Constant::create(element::i32, Shape{1}, shape_vec); - auto result = - context.mark_node(std::make_shared(context.get_input(0), shape_const, true)); - return {result}; - } - auto reshape = std::make_shared(context.get_input(0), context.get_input(1), true); - return {context.mark_node(reshape)}; - } -}; - } // namespace op } // namespace pytorch } // namespace frontend diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index a73c13814d7663..721929f582e9c5 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -273,7 +273,6 @@ OP_CONVERTER(translate_chunk_fx); OP_CONVERTER(translate_div_fx); OP_CONVERTER(translate_div_fx_); OP_CONVERTER(translate_embedding_bag_fx); -OP_CONVERTER(translate_expand_fx); OP_CONVERTER(translate_eye_fx); 
OP_CONVERTER(translate_fake_quantize_per_channel_affine_fx); OP_CONVERTER(translate_fake_quantize_per_tensor_affine_fx); @@ -297,7 +296,7 @@ OP_CONVERTER(translate_new_zeros_fx); OP_CONVERTER(translate_ones_fx); OP_CONVERTER(translate_ones_like_fx); OP_CONVERTER(translate_reflection_pad_nd_fx); -OP_CONVERTER(translate_reshape_fx); +OP_CONVERTER(translate_repeat_fx); OP_CONVERTER(translate_rsub_fx); OP_CONVERTER(translate_scalar_tensor_fx); OP_CONVERTER(translate_scaled_dot_product_attention_fx); @@ -763,7 +762,7 @@ const std::unordered_map get_supported_ops_fx() { {"aten._scaled_dot_product_flash_attention_for_cpu.default", op::translate_scaled_dot_product_attention_fx}, {"aten._softmax.default", op::translate_softmax_fx}, {"aten._to_copy.default", op::translate_to_fx}, - {"aten._unsafe_view.default", op::translate_reshape_fx}, + {"aten._unsafe_view.default", op::translate_reshape}, {"aten.abs.default", op::translate_1to1_match_1_inputs}, {"aten.acos.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten.acosh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, @@ -838,7 +837,7 @@ const std::unordered_map get_supported_ops_fx() { {"aten.erfc.default", op::translate_erfc}, {"aten.exp.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten.expm1.default", op::translate_expm1}, - {"aten.expand.default", op::translate_expand_fx}, + {"aten.expand.default", op::translate_expand}, {"aten.eye.m", op::translate_eye_fx}, {"aten.fake_quantize_per_channel_affine_cachemask.default", op::translate_fake_quantize_per_channel_affine_fx}, {"aten.fill.Scalar", op::translate_fill}, @@ -932,7 +931,7 @@ const std::unordered_map get_supported_ops_fx() { {"aten.reflection_pad3d.default", op::translate_reflection_pad_nd_fx}, {"aten.relu.default", op::translate_1to1_match_1_inputs}, {"aten.relu_.default", op::inplace_op>}, - {"aten.repeat.default", op::translate_1to1_match_2_inputs}, + {"aten.repeat.default", op::translate_repeat_fx}, {"aten.roll.default", op::translate_roll}, {"aten.rsqrt.default", op::translate_rsqrt}, {"aten.rsub.Scalar", op::translate_rsub_fx}, @@ -978,7 +977,7 @@ const std::unordered_map get_supported_ops_fx() { {"aten.upsample_nearest2d.default", op::translate_upsample_nearest2d}, {"aten.var.correction", op::translate_var_fx}, {"aten.var_mean.correction", op::translate_var_mean_fx}, - {"aten.view.default", op::translate_reshape_fx}, + {"aten.view.default", op::translate_reshape}, {"aten.where.self", op::translate_where}, {"aten.zeros.default", op::translate_zeros_fx}, {"aten.zeros.names", op::translate_zeros_fx}, @@ -990,6 +989,7 @@ const std::unordered_map get_supported_ops_fx() { {"quantized_decomposed.quantize_per_channel.default", op::translate_quantize_per_channel_fx}, {"quantized_decomposed.dequantize_per_tensor.default", op::skip_node}, {"quantized_decomposed.dequantize_per_channel.default", op::skip_node}, + {"inlined.constant.default", op::translate_constant}, // this is a custom ov type }; }; diff --git a/src/frontends/pytorch/src/translate_session.hpp b/src/frontends/pytorch/src/translate_session.hpp index df669dbabe1fae..21497459d85126 100644 --- a/src/frontends/pytorch/src/translate_session.hpp +++ b/src/frontends/pytorch/src/translate_session.hpp @@ -51,9 +51,8 @@ class TranslateSession { // and to the output produced during conversion of this node std::map, Output>> m_may_be_alias; -private: OutputVector convert_node(const NodeContext& context); - +private: const frontend::InputModel::Ptr m_input_model; const 
std::unordered_map& m_translator_map; std::shared_ptr m_telemetry; diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp index 5cc7ec21f30911..7d3f545b1ce211 100644 --- a/src/frontends/pytorch/src/utils.cpp +++ b/src/frontends/pytorch/src/utils.cpp @@ -639,30 +639,6 @@ Output masked_fill(ov::pass::NodeRegistry& rg, return rg.make(bool_mask, _value, data); } -Output concat_list_from_inputs(const NodeContext& context, size_t begin, size_t end) { - OutputVector list_elems; - for (size_t i = begin; i < end; i++) { - if (context.get_input_type(i).as().element_type.is()) { - auto const_val = context.const_input(i); - std::vector dim_vec; - dim_vec.push_back(const_val); - auto dim_const = v0::Constant::create(element::i64, Shape{1}, dim_vec); - list_elems.push_back(dim_const); - } else { - auto input_dim = context.get_input(static_cast(i)); - if (input_dim.get_partial_shape().rank() == 0) { - auto zero = v0::Constant::create(element::i32, Shape{}, {0}); - auto unsqueezed_dim = context.mark_node(std::make_shared(input_dim, zero)); - list_elems.push_back(unsqueezed_dim); - } else { - list_elems.push_back(input_dim); - } - } - } - auto concat = std::make_shared(list_elems, 0); - return concat; -} - Output masked_select(const NodeContext& context, const Output& data, const Output& mask) { auto input_order = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {1, 0})); auto nonzero = context.mark_node(std::make_shared(mask)); diff --git a/src/frontends/pytorch/src/utils.hpp b/src/frontends/pytorch/src/utils.hpp index 9346b9e18b94a3..1a57c49932b0e9 100644 --- a/src/frontends/pytorch/src/utils.hpp +++ b/src/frontends/pytorch/src/utils.hpp @@ -125,8 +125,6 @@ Output masked_fill(ov::pass::NodeRegistry& rg, const Output& mask, const Output& value); -Output concat_list_from_inputs(const NodeContext& context, size_t begin, size_t end); - Output masked_select(const NodeContext& context, const Output& data, const Output& mask); Output flatten(ov::pass::NodeRegistry& rg, const Output& value, size_t axis); @@ -310,11 +308,11 @@ class DummyDecoder : public TorchDecoder { virtual bool may_produce_alias(size_t in_index, size_t out_index) const override { FRONT_END_NOT_IMPLEMENTED(may_produce_alias); } - bool is_input_inlined(size_t index) const override { + virtual bool is_input_inlined(size_t index) const override { FRONT_END_NOT_IMPLEMENTED(is_input_inlined); } - virtual OutputVector inlined_input(size_t index) const override { - FRONT_END_NOT_IMPLEMENTED(inlined_input); + virtual std::shared_ptr get_inlined_input_decoder(size_t index) const override { + FRONT_END_NOT_IMPLEMENTED(get_inlined_input_decoder); } virtual ov::Any get_attribute(const std::string& name) const override { FRONT_END_NOT_IMPLEMENTED(get_attribute); diff --git a/tests/model_hub_tests/pytorch/test_timm.py b/tests/model_hub_tests/pytorch/test_timm.py index 0c151e804720ca..5c6a76fadb85e8 100644 --- a/tests/model_hub_tests/pytorch/test_timm.py +++ b/tests/model_hub_tests/pytorch/test_timm.py @@ -51,9 +51,12 @@ class TestTimmConvertModel(TestTorchConvertModel): def load_model(self, model_name, model_link): m = timm.create_model(model_name, pretrained=True) cfg = timm.get_pretrained_cfg(model_name) - shape = [1] + list(cfg.input_size) - self.example = (torch.randn(shape),) - self.inputs = (torch.randn(shape),) + shape = list(cfg.input_size) + self.example = (torch.randn([2] + shape),) + self.inputs = (torch.randn([3] + shape),) + if getattr(self, "mode", None) == "export": + batch = 
torch.export.Dim("batch", min=1, max=3) + self.export_kwargs = {"dynamic_shapes": {"x": {0: batch}}} return m def infer_fw_model(self, model_obj, inputs): diff --git a/tests/model_hub_tests/pytorch/timm_models b/tests/model_hub_tests/pytorch/timm_models index 9732569b648245..535368850e2623 100644 --- a/tests/model_hub_tests/pytorch/timm_models +++ b/tests/model_hub_tests/pytorch/timm_models @@ -126,13 +126,13 @@ efficientnet_es_pruned.in1k,None efficientnet_lite0.ra_in1k,None efficientnetv2_rw_s.ra2_in1k,None efficientnetv2_rw_t.ra2_in1k,None -efficientvit_b0.r224_in1k,None,xfail_export,SpecViolationError: Node.meta _enter_autocast is missing val field. -efficientvit_b1.r224_in1k,None,xfail_export,SpecViolationError: Node.meta _enter_autocast is missing val field. -efficientvit_b2.r224_in1k,None,xfail_export,SpecViolationError: Node.meta _enter_autocast is missing val field. -efficientvit_b3.r224_in1k,None,xfail_export,SpecViolationError: Node.meta _enter_autocast is missing val field. -efficientvit_l1.r224_in1k,None,xfail_export,SpecViolationError: Node.meta _enter_autocast is missing val field. -efficientvit_l2.r224_in1k,None,xfail_export,SpecViolationError: Node.meta _enter_autocast is missing val field. -efficientvit_l3.r224_in1k,None,xfail_export,SpecViolationError: Node.meta _enter_autocast is missing val field. +efficientvit_b0.r224_in1k,None +efficientvit_b1.r224_in1k,None +efficientvit_b2.r224_in1k,None +efficientvit_b3.r224_in1k,None +efficientvit_l1.r224_in1k,None +efficientvit_l2.r224_in1k,None +efficientvit_l3.r224_in1k,None efficientvit_m0.r224_in1k,None efficientvit_m1.r224_in1k,None efficientvit_m2.r224_in1k,None @@ -147,8 +147,8 @@ eva02_base_patch14_224.mim_in22k,None eva02_base_patch16_clip_224.merged2b,None eva02_large_patch14_clip_224.merged2b,None fastvit_ma36.apple_dist_in1k,None -fastvit_mci0.apple_mclip,None -fastvit_mci1.apple_mclip,None +fastvit_mci0.apple_mclip,None,xfail,Accuracy validation failed +fastvit_mci1.apple_mclip,None,xfail,Accuracy validation failed fastvit_mci2.apple_mclip,None,xfail,Accuracy validation failed fastvit_s12.apple_dist_in1k,None fastvit_sa12.apple_dist_in1k,None @@ -524,7 +524,7 @@ vit_base_patch16_224.augreg2_in21k_ft_in1k,None vit_base_patch16_224_miil.in21k,None vit_base_patch16_clip_224.datacompxl,None vit_base_patch16_clip_quickgelu_224.metaclip_2pt5b,None -vit_base_patch16_rope_reg1_gap_256.sbb_in1k,None,xfail_trace,Argument shapes are inconsistent +vit_base_patch16_rope_reg1_gap_256.sbb_in1k,None vit_base_patch16_rpn_224.sw_in1k,None vit_base_patch16_siglip_224.webli,None vit_base_patch16_siglip_gap_224.webli,None @@ -535,7 +535,7 @@ vit_base_patch8_224.augreg2_in21k_ft_in1k,None vit_base_r50_s16_224.orig_in21k,None vit_betwixt_patch16_reg1_gap_256.sbb_in1k,None vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k,None -vit_betwixt_patch16_rope_reg4_gap_256.sbb_in1k,None,xfail_trace,Argument shapes are inconsistent +vit_betwixt_patch16_rope_reg4_gap_256.sbb_in1k,None vit_betwixt_patch32_clip_224.tinyclip_laion400m,None vit_huge_patch14_224.mae,None vit_huge_patch14_gap_224.in1k_ijepa,None @@ -549,7 +549,7 @@ vit_medium_patch16_gap_240.sw_in12k,None vit_medium_patch16_reg1_gap_256.sbb_in1k,None vit_medium_patch16_reg4_gap_256.sbb_in12k,None vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k,None -vit_mediumd_patch16_rope_reg1_gap_256.sbb_in1k,None,xfail_trace,Argument shapes are inconsistent +vit_mediumd_patch16_rope_reg1_gap_256.sbb_in1k,None vit_pwee_patch16_reg1_gap_256.sbb_in1k,None vit_relpos_base_patch16_224.sw_in1k,None 
vit_relpos_base_patch16_clsgap_224.sw_in1k,None diff --git a/tests/model_hub_tests/pytorch/torch_utils.py b/tests/model_hub_tests/pytorch/torch_utils.py index f7a3837bfa2695..48865e151910fb 100644 --- a/tests/model_hub_tests/pytorch/torch_utils.py +++ b/tests/model_hub_tests/pytorch/torch_utils.py @@ -1,9 +1,9 @@ # Copyright (C) 2018-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import os import pytest import torch -import os from models_hub_common.test_convert_model import TestConvertModel from models_hub_common.utils import get_models_list from openvino import convert_model @@ -46,11 +46,12 @@ def extract_unsupported_ops_from_exception(e: str) -> list: class TestTorchConvertModel(TestConvertModel): cached_model = None + def setup_class(self): torch.set_grad_enabled(False) def load_model(self, model_name, model_link): - raise "load_model is not implemented" + raise RuntimeError("load_model is not implemented") def get_inputs_info(self, model_obj): return None @@ -69,12 +70,15 @@ def convert_model_impl(self, model_obj): model_obj.eval() graph = None + export_kwargs = {} + if getattr(self, "export_kwargs", None): + export_kwargs = self.export_kwargs if isinstance(self.example, dict): pt_res = model_obj(**self.example) - graph = export(model_obj, args=tuple(), kwargs=self.example) + graph = export(model_obj, args=tuple(), kwargs=self.example, **export_kwargs) else: pt_res = model_obj(*self.example) - graph = export(model_obj, self.example) + graph = export(model_obj, self.example, **export_kwargs) ov_model = convert_model(graph, verbose=True) if isinstance(pt_res, dict):
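
For reference, below is a minimal sketch (not part of the patch) of the end-to-end flow this change enables: export a model with `torch.export` using a symbolic batch dimension, then convert the resulting `ExportedProgram` with `openvino.convert_model`, mirroring the `test_timm.py` and `torch_utils.py` changes above. The toy module and all names in it are illustrative assumptions, not code from this PR.

```python
# Hedged sketch of dynamic-shape torch.export conversion (illustrative only).
import torch
from torch.export import Dim, export
from openvino import convert_model


class TinyNet(torch.nn.Module):
    # Hypothetical toy model; any module with a single tensor input "x" works.
    def forward(self, x):
        return x.relu().mean(dim=(2, 3))


model = TinyNet().eval()
example = (torch.randn(2, 3, 224, 224),)

# Mark dim 0 of input "x" as dynamic, as the updated test_timm.py does for timm models.
batch = Dim("batch", min=1, max=8)
exported = export(model, example, dynamic_shapes={"x": {0: batch}})

# With this PR, the FX decoder keeps SymInt dimensions dynamic in the converted model,
# so the OpenVINO input shape should report a dynamic batch dimension.
ov_model = convert_model(exported)
print(ov_model.inputs[0].get_partial_shape())  # expected along the lines of [?,3,224,224]
```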