diff --git a/src/frontends/pytorch/src/op/rot90.cpp b/src/frontends/pytorch/src/op/rot90.cpp new file mode 100644 index 00000000000000..4790d28600b539 --- /dev/null +++ b/src/frontends/pytorch/src/op/rot90.cpp @@ -0,0 +1,91 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/scatter_elements_update.hpp" +#include "openvino/core/validation_util.hpp" +#include "openvino/op/shape_of.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +using namespace ov::op; + +OutputVector translate_rot90(const NodeContext& context) { + num_inputs_check(context, 1, 3); + auto input = context.get_input(0); + int k = context.input_is_none(1) ? 1 : context.const_input(1); + auto dims = context.input_is_none(2) + ? context.mark_node(v0::Constant::create(element::i32, Shape{2}, {0,1})) + : get_input_as_i32(context, 2); + const auto& partial_shape = input.get_partial_shape(); + const auto ndims = partial_shape.rank().get_length(); + + std::shared_ptr rank = std::make_shared( + ov::element::i32, ov::Shape{}, std::vector{static_cast(ndims)}); + auto dims_norm = normalize_axis(context, dims, rank); + auto dims_const = std::dynamic_pointer_cast(dims_norm.get_node_shared_ptr()); + auto dims_values = dims_const->cast_vector(); + + auto start = v0::Constant::create(element::i32, {}, {0}); + auto step = v0::Constant::create(element::i32, {}, {1}); + auto range = std::make_shared(start, rank, step, element::i32); + + auto axis_0 = v0::Constant::create(element::i32, Shape{}, {0}); + auto dim0_node = std::make_shared( + v0::Constant::create(element::i32, {}, {dims_values[0]}), axis_0); + auto dim1_node = std::make_shared( + v0::Constant::create(element::i32, {}, {dims_values[1]}), axis_0); + + auto indices = std::make_shared(OutputVector{dim0_node, dim1_node}, 0); + auto updates = std::make_shared( + OutputVector{dim1_node, dim0_node}, 0); + + Output scatter = std::make_shared( + range, indices, updates, axis_0); + if (const auto scatter_const = ov::util::get_constant_from_source(scatter)) { + scatter = context.mark_node(scatter_const); + } else { + context.mark_nodes( + {start, step, range, axis_0, dim0_node, dim1_node, indices, updates, scatter.get_node_shared_ptr()}); + } + + PYTORCH_OP_CONVERSION_CHECK(dims_values.size() == 2, + "Expected total rotation dims == 2, but got dims = ", + dims_values.size()); + PYTORCH_OP_CONVERSION_CHECK(ndims >= 2, + "Expected total dims >= 2, but got total dims = ", + ndims); + PYTORCH_OP_CONVERSION_CHECK(dims_values[0] != dims_values[1], + "Rotation dimensions must be different, but got dim0 = " + + std::to_string(dims_values[0]) + " and dim1 = " + std::to_string(dims_values[1])); + + k = k % 4; + Output rotated; + + if (k == 1 || k == 3) { + Output flip_dims = (k ==1) ? dim1_node : dim0_node; + auto flipped = create_flip(input, flip_dims); + rotated = context.mark_node(std::make_shared(flipped, scatter)); + } else if (k == 2) { + rotated = create_flip(input, dims_norm); + } else { + rotated = input; + } + + return {rotated}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index a73c13814d7663..2ccaabe186d0e7 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -200,6 +200,7 @@ OP_CONVERTER(translate_reshape_as); OP_CONVERTER(translate_rnn); OP_CONVERTER(translate_roi_align); OP_CONVERTER(translate_roll); +OP_CONVERTER(translate_rot90); OP_CONVERTER(translate_round); OP_CONVERTER(translate_rsqrt); OP_CONVERTER(translate_rsub); @@ -624,6 +625,7 @@ const std::unordered_map get_supported_ops_ts() { {"aten::rnn_relu", op::translate_rnn}, {"aten::rnn_tanh", op::translate_rnn}, {"aten::roll", op::translate_roll}, + {"aten::rot90", op::translate_rot90}, {"aten::round", op::translate_round}, {"aten::rsqrt", op::optional_out}, {"aten::rsqrt_", op::inplace_op}, diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp index 5cc7ec21f30911..a48fce79bb73e3 100644 --- a/src/frontends/pytorch/src/utils.cpp +++ b/src/frontends/pytorch/src/utils.cpp @@ -167,6 +167,15 @@ Output normalize_axis(const NodeContext& context, const Output& axis } } +Output create_flip(const Output& x, const Output& axis) { + auto minus_one = v0::Constant::create(element::i32, Shape{}, {-1}); + auto minimum_int = v0::Constant::create(element::i32, Shape{}, {std::numeric_limits::min()}); + auto axis_shape = std::make_shared(axis, element::i32); + auto start = std::make_shared(minus_one, axis_shape); + auto stop = std::make_shared(minimum_int, axis_shape); + return std::make_shared(x, start, stop, start, axis); +}; + std::shared_ptr numel(const NodeContext& context, const Output& x, element::Type output_type) { auto input_shape = context.mark_node(std::make_shared(x, output_type)); auto axes = context.mark_node(v0::Constant::create(output_type, Shape({1}), {0})); diff --git a/src/frontends/pytorch/src/utils.hpp b/src/frontends/pytorch/src/utils.hpp index 9346b9e18b94a3..024e9158e9cdcd 100644 --- a/src/frontends/pytorch/src/utils.hpp +++ b/src/frontends/pytorch/src/utils.hpp @@ -59,6 +59,8 @@ std::shared_ptr get_node_axes_range(const NodeContext& context, const Outp Output normalize_axis(const NodeContext& context, const Output& axis, const Output& input_node); +Output create_flip(const Output& x, const Output& axis); + std::shared_ptr numel(const NodeContext& context, const Output& x, element::Type output_type = element::i32); diff --git a/tests/layer_tests/pytorch_tests/test_rot90.py b/tests/layer_tests/pytorch_tests/test_rot90.py new file mode 100644 index 00000000000000..ac369353262746 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_rot90.py @@ -0,0 +1,38 @@ +# Copyright (C) 2018-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import numpy as np + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestRot90(PytorchLayerTest): + def _prepare_input(self): + + x = np.arange(24).reshape(2, 3, 4).astype(np.float32) + return (x,) + + def create_model(self, k, dims): + import torch + + class aten_rot90(torch.nn.Module): + def __init__(self, k=1, dims=(0, 1)): + super(aten_rot90, self).__init__() + self.k = k + self.dims = dims + + def forward(self, x): + return torch.rot90(x, self.k, self.dims) + + ref_net = None + return aten_rot90(k, dims), ref_net, "aten::rot90" + + @pytest.mark.parametrize("k", [1, 2, 3, 4, 5]) + @pytest.mark.parametrize("dims", [(0, 1), (0, 2), (1, 2)]) + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.precommit_torch_export + def test_rot90(self, k, dims, ie_device, precision, ir_version): + self._test(*self.create_model(k, dims), ie_device, precision, ir_version, + trace_model=True,dynamic_shapes=False) \ No newline at end of file