Skip to content

Commit

Permalink
[Frontends][TorchFX] torch.ops.aten._unsafe_index support (#27617)
Browse files Browse the repository at this point in the history
### Details:
 - torch.ops.aten._unsafe_index is supported by the Torch frontend
 - To support Yolo(V8 and V11) TorchFX models quantization with NNCF

### Tickets:
- 141640

---------

Co-authored-by: Roman Kazantsev <[email protected]>
  • Loading branch information
daniil-lyakhov and rkazants authored Nov 25, 2024
1 parent c17f3a5 commit 147d0af
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def __init__(self, options):
"torch.ops.aten.hardtanh.default": None,
"torch.ops.aten.hardtanh_.default": None,
"torch.ops.aten.index.Tensor": None,
"torch.ops.aten._unsafe_index.Tensor": None,
"torch.ops.aten.index_select.default": None,
"torch.ops.aten.isfinite.default": None,
"torch.ops.aten.isinf.default": None,
Expand Down
1 change: 1 addition & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.hardtanh.default", op::translate_hardtanh},
{"aten.hardtanh_.default", op::inplace_op<op::translate_hardtanh>},
{"aten.index.Tensor", op::translate_index_fx},
{"aten._unsafe_index.Tensor", op::translate_index_fx},
{"aten.index_select.default", op::translate_index_select},
{"aten.isfinite.default", op::translate_1to1_match_1_inputs<opset10::IsFinite>},
{"aten.isinf.default", op::translate_1to1_match_1_inputs<opset10::IsInf>},
Expand Down
11 changes: 7 additions & 4 deletions tests/layer_tests/pytorch_tests/test_index_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def _prepare_input(self, input_shape):
import numpy as np
return (np.random.randn(*input_shape).astype(np.float32),)

def create_model(self, indices_list):
def create_model(self, indices_list, safe: bool):
import torch

class aten_index_tensor(torch.nn.Module):
Expand All @@ -20,7 +20,9 @@ def __init__(self, indices_list):
self.indices_list = indices_list

def forward(self, x):
return torch.ops.aten.index.Tensor(x, self.indices_list)
if safe:
return torch.ops.aten.index.Tensor(x, self.indices_list)
return torch.ops.aten._unsafe_index.Tensor(x, self.indices_list)

ref_net = None

Expand All @@ -35,15 +37,16 @@ def forward(self, x):

@pytest.mark.nightly
@pytest.mark.precommit_torch_export
@pytest.mark.parametrize('safe', [True, False])
@pytest.mark.parametrize(('input_shape', 'indices_list'), [
([3, 7], [[0], [5, 3, 0]]),
([3, 7, 6], [[0], None, None]),
([3, 7, 6], [[0], None, [5, 0, 3]]),
([3, 7, 6], [[0, 2, 1], None, [5, 0, 3]]),
([3, 7, 6], [[0, 2, 1], [4], [5, 0, 3]]),
])
def test_index_tensor(self, input_shape, indices_list, ie_device, precision, ir_version):
def test_index_tensor(self, safe, input_shape, indices_list, ie_device, precision, ir_version):
if not PytorchLayerTest.use_torch_export():
pytest.skip(reason='aten.index.Tensor test is supported only on torch.export()')
self._test(*self.create_model(indices_list), ie_device, precision, ir_version,
self._test(*self.create_model(indices_list, safe), ie_device, precision, ir_version,
kwargs_to_prepare_input={'input_shape': input_shape})

0 comments on commit 147d0af

Please sign in to comment.