[PT FE] Add aten::rot90 #28224
base: master
@@ -0,0 +1,67 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <numeric>
#include <string>

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/transpose.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_rot90(const NodeContext& context) {
    num_inputs_check(context, 1, 3);
    auto input = context.get_input(0);
    int k = context.input_is_none(1) ? 1 : context.const_input<int64_t>(1);
    std::vector<int64_t> dims = context.input_is_none(2) ? std::vector<int64_t>{0, 1}
                                                         : context.const_input<std::vector<int64_t>>(2);
    const auto& partial_shape = input.get_partial_shape();
    const auto ndims = partial_shape.rank().get_length();

    PYTORCH_OP_CONVERSION_CHECK(dims.size() == 2,
                                "Expected total rotation dims == 2, but got dims = ",
                                dims.size());
    PYTORCH_OP_CONVERSION_CHECK(ndims >= 2,
                                "Expected total dims >= 2, but got total dims = ",
                                ndims);
    PYTORCH_OP_CONVERSION_CHECK(dims[0] != dims[1],
                                "Rotation dimensions must be different, but got dim0 = " +
                                    std::to_string(dims[0]) + " and dim1 = " + std::to_string(dims[1]));

    // Normalize negative axes into the range [0, ndims).
    for (auto& dim : dims) {
        dim = (dim + ndims) % ndims;
    }
Comment on lines +36 to +38

Use get_input to take dims as a Node instead of reading it with const_input.

Reply: If I change dims to be a Node, I am not sure how I can extract the individual values from dims that are needed in the k == 1 or k == 3 branch.

Reply: I am not sure what ...
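One way to pull individual axis values out of a 1-D dims node is a Gather along axis 0. A minimal sketch, assuming dims was obtained via context.get_input(2) and that "openvino/op/gather.hpp" is included; the helper name gather_dim is illustrative, not an existing utility:

// Extract element `index` from the 1-D `dims` node as a scalar node.
Output<Node> gather_dim(const NodeContext& context, const Output<Node>& dims, int64_t index) {
    auto axis = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
    auto idx = context.mark_node(v0::Constant::create(element::i32, Shape{}, {index}));
    return context.mark_node(std::make_shared<v8::Gather>(dims, idx, axis));
}

The two rotation axes would then be gather_dim(context, dims_node, 0) and gather_dim(context, dims_node, 1), usable wherever the k == 1 or k == 3 branch currently reads dims[0] and dims[1].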
    // Normalize k so that negative rotation counts also map into {0, 1, 2, 3}.
    k = ((k % 4) + 4) % 4;
    Output<Node> rotated;

    if (k == 1 || k == 3) {
        int64_t flip_dim = (k == 1) ? dims[1] : dims[0];
        auto flip_dims = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {flip_dim}));
        auto flipped = create_flip(input, flip_dims);
        std::vector<int64_t> perm_values(ndims);
        std::iota(perm_values.begin(), perm_values.end(), 0);
        std::swap(perm_values[dims[0]], perm_values[dims[1]]);
        auto perm = context.mark_node(
            v0::Constant::create(element::i32, Shape{static_cast<size_t>(ndims)}, perm_values));
        rotated = context.mark_node(std::make_shared<v1::Transpose>(flipped, perm));
    } else if (k == 2) {
        size_t dims_size = dims.size();
        auto flip_dims = context.mark_node(v0::Constant::create(element::i32, Shape{dims_size}, dims));
        rotated = create_flip(input, flip_dims);
Comment: This fails to build with an error: there is no such function as create_flip.
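If no shared flip utility is available, one option is to build the flip directly on opset1 Reverse in "index" mode. A minimal sketch, assuming "openvino/op/reverse.hpp" is included; the helper name flip_axes is illustrative, not part of this patch:

// Flip `input` along every axis listed in the 1-D i32 `axes` node.
Output<Node> flip_axes(const NodeContext& context,
                       const Output<Node>& input,
                       const Output<Node>& axes) {
    return context.mark_node(
        std::make_shared<v1::Reverse>(input, axes, v1::Reverse::Mode::INDEX));
}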
    } else {
        rotated = input;
    }

    return {rotated};
}

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov
@@ -0,0 +1,38 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import numpy as np

from pytorch_layer_test_class import PytorchLayerTest


class TestRot90(PytorchLayerTest):
    def _prepare_input(self):
        x = np.arange(24).reshape(2, 3, 4).astype(np.float32)
        return (x,)

    def create_model(self, k, dims):
        import torch

        class aten_rot90(torch.nn.Module):
            def __init__(self, k=1, dims=(0, 1)):
                super(aten_rot90, self).__init__()
                self.k = k
                self.dims = dims

            def forward(self, x):
                return torch.rot90(x, self.k, self.dims)

        ref_net = None
        return aten_rot90(k, dims), ref_net, "aten::rot90"

    @pytest.mark.parametrize("k", [1, 2, 3, 4, 5])
    @pytest.mark.parametrize("dims", [(0, 1), (0, 2), (1, 2)])
    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.precommit_torch_export
    def test_rot90(self, k, dims, ie_device, precision, ir_version):
        self._test(*self.create_model(k, dims), ie_device, precision, ir_version,
                   trace_model=True, dynamic_shapes=False)
Review comment: Instead of const_input, get dims as a Node using get_input and use it as is, to avoid requiring it to be constant and to avoid creating a Constant from dims later.
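A minimal sketch of what that suggestion could look like inside translate_rot90; the variable name dims_node is illustrative, and the {0, 1} default mirrors the current behavior when the dims input is absent:

Output<Node> dims_node;
if (context.input_is_none(2)) {
    // Default rotation plane {0, 1} when dims is not given.
    dims_node = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {0, 1}));
} else {
    // Take dims as a graph node; it no longer needs to be a compile-time constant.
    dims_node = context.get_input(2);
}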