
Commit

implement replicate{1,2,3} pad
Signed-off-by: 11happy <[email protected]>
11happy committed Jan 4, 2025
1 parent a091846 commit dd2430a
Showing 3 changed files with 55 additions and 2 deletions.
8 changes: 8 additions & 0 deletions src/frontends/pytorch/src/op/pad.cpp
@@ -134,6 +134,14 @@ OutputVector translate_reflection_pad_nd_fx(const NodeContext& context) {
    return translate_pad_common(context, data, paddings, pad_value, "reflect");
}

OutputVector translate_replicate_pad_nd_fx(const NodeContext& context) {
    num_inputs_check(context, 2, 2);
    auto data = context.get_input(0);
    auto paddings = context.const_input<std::vector<int64_t>>(1);
    Output<Node> pad_value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
    return translate_pad_common(context, data, paddings, pad_value, "replicate");
}

} // namespace op
} // namespace pytorch
} // namespace frontend
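The "replicate" mode repeats the border values of the input, so the zero pad_value constant created above is effectively a placeholder required by translate_pad_common rather than a value that appears in the output. A minimal sketch of the intended semantics, using NumPy's edge mode as a reference (the tensor and shapes here are illustrative only, not taken from the repository):

import numpy as np
import torch

x = torch.arange(6, dtype=torch.float32).reshape(1, 2, 3)

# ReplicationPad1d((1, 2)) pads the last dimension with 1 element on the left
# and 2 on the right, repeating the edge values.
out = torch.nn.ReplicationPad1d((1, 2))(x)

# NumPy's "edge" mode performs the same replication, which is the behaviour
# the "replicate" pad mode passed to translate_pad_common is expected to map to.
ref = np.pad(x.numpy(), ((0, 0), (0, 0), (1, 2)), mode="edge")
assert np.array_equal(out.numpy(), ref)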
4 changes: 4 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
@@ -297,6 +297,7 @@ OP_CONVERTER(translate_new_zeros_fx);
OP_CONVERTER(translate_ones_fx);
OP_CONVERTER(translate_ones_like_fx);
OP_CONVERTER(translate_reflection_pad_nd_fx);
OP_CONVERTER(translate_replicate_pad_nd_fx);
OP_CONVERTER(translate_reshape_fx);
OP_CONVERTER(translate_rsub_fx);
OP_CONVERTER(translate_scalar_tensor_fx);
@@ -930,6 +931,9 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.reflection_pad1d.default", op::translate_reflection_pad_nd_fx},
{"aten.reflection_pad2d.default", op::translate_reflection_pad_nd_fx},
{"aten.reflection_pad3d.default", op::translate_reflection_pad_nd_fx},
{"aten.replicate_pad1d.default", op::translate_replicate_pad_nd_fx},
{"aten.replicate_pad2d.default", op::translate_replicate_pad_nd_fx},
{"aten.replicate_pad3d.default", op::translate_replicate_pad_nd_fx},
{"aten.relu.default", op::translate_1to1_match_1_inputs<opset10::Relu>},
{"aten.relu_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
{"aten.repeat.default", op::translate_1to1_match_2_inputs<opset10::Tile>},
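These entries map the ATen replication-pad ops produced by torch.export onto the new converter. A rough end-to-end sketch, assuming openvino.convert_model accepts a torch.export ExportedProgram on this path (the exact entry point is an assumption; the precommit_torch_export tests below exercise the equivalent flow):

import torch
import openvino as ov

class Model(torch.nn.Module):
    def forward(self, x):
        # 2D replication padding: (left, right, top, bottom)
        return torch.nn.functional.pad(x, (1, 2, 3, 4), mode="replicate")

# Exporting is expected to yield an aten.replication_pad2d.default node,
# which get_supported_ops_fx resolves to translate_replicate_pad_nd_fx.
exported = torch.export.export(Model(), (torch.randn(1, 3, 8, 8),))

# Assumption: convert_model takes the ExportedProgram directly on the FX path.
ov_model = ov.convert_model(exported)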
45 changes: 43 additions & 2 deletions tests/layer_tests/pytorch_tests/test_pad.py
@@ -219,9 +219,9 @@ def __init__(self, pads):
                if ndim == 1:
                    self.pad = torch.nn.ReflectionPad1d(pads)
                elif ndim == 2:
-                   self.pad = torch.nn.ReflectionPad1d(pads)
+                   self.pad = torch.nn.ReflectionPad2d(pads)
                elif ndim == 3:
-                   self.pad = torch.nn.ReflectionPad1d(pads)
+                   self.pad = torch.nn.ReflectionPad3d(pads)
                else:
                    raise Exception("Unsupported pads")

@@ -244,3 +244,44 @@ def test_reflection_padnd(self, pads, dtype, ie_device, precision, ir_version):
        print(ndim)
        self._test(*self.create_model(pads), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"ndim": ndim, "dtype": dtype})

class TestReplicatePad(PytorchLayerTest):
    def _prepare_input(self, ndim=4, dtype="float32"):
        import numpy as np
        input_5d_shape = [5, 9, 1, 1, 2, 4]
        return (np.random.randn(*input_5d_shape[:ndim]).astype(dtype),)

    def create_model(self, pads):
        import torch
        import torch.nn.functional as F

        class aten_pad(torch.nn.Module):
            def __init__(self, pads):
                super().__init__()
                ndim = len(pads) // 2
                if ndim == 1:
                    self.pad = torch.nn.ReplicationPad1d(pads)
                elif ndim == 2:
                    self.pad = torch.nn.ReplicationPad2d(pads)
                elif ndim == 3:
                    self.pad = torch.nn.ReplicationPad3d(pads)
                else:
                    raise Exception("Unsupported pads")

            def forward(self, x):
                return self.pad(x)

        return aten_pad(pads), None, "aten::pad"

    @pytest.mark.parametrize("dtype", ["float32", "float64", "int32"])
    @pytest.mark.parametrize("pads", [
        (1, 2),
        (1, 2, 3, 4),
        (1, 2, 3, 4, 3, 2),
    ])
    @pytest.mark.nightly
    @pytest.mark.precommit_torch_export
    def test_replicate_padnd(self, pads, dtype, ie_device, precision, ir_version):
        ndim = len(pads) // 2 + 2
        self._test(*self.create_model(pads), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"ndim": ndim, "dtype": dtype})
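The added cases can be selected with pytest's keyword filter, for example pytest tests/layer_tests/pytorch_tests/test_pad.py -k replicate_padnd, assuming the ie_device, precision and ir_version fixtures are supplied by the layer-test conftest in the same way as for the existing reflection-pad tests.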
