Skip to content

Commit

Permalink
[PT FE] Support dynamic shapes torch.export
Browse files Browse the repository at this point in the history
Signed-off-by: Maxim Vafin <[email protected]>
  • Loading branch information
mvafin committed Jan 7, 2025
1 parent 26e5fe9 commit 043bc89
Show file tree
Hide file tree
Showing 21 changed files with 416 additions and 387 deletions.
374 changes: 225 additions & 149 deletions src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -515,12 +515,12 @@ def may_produce_alias(self, in_index: int, out_index: int) -> bool:
# Sometimes pytorch fails to get result with IndexError exception while these indexes exist in node
return False

def inlined_input(self, index):
return []

def is_input_inlined(self, index):
return False

def get_inlined_input_decoder(self, index):
return None

def get_attribute(self, name):
return OVAny(None)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,14 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
PYBIND11_OVERRIDE_PURE(bool, TorchDecoder, may_produce_alias, in_index, out_index);
}

ov::OutputVector inlined_input(size_t index) const override {
PYBIND11_OVERRIDE_PURE(ov::OutputVector, TorchDecoder, inlined_input, index);
}

bool is_input_inlined(size_t index) const override {
PYBIND11_OVERRIDE_PURE(bool, TorchDecoder, is_input_inlined, index);
}

std::shared_ptr<TorchDecoder> get_inlined_input_decoder(size_t index) const override {
PYBIND11_OVERRIDE_PURE(std::shared_ptr<TorchDecoder>, TorchDecoder, get_inlined_input_decoder, index);
}

ov::Any get_attribute(const std::string &name) const override{
PYBIND11_OVERRIDE_PURE(ov::Any, TorchDecoder, get_attribute, name);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,13 @@ class TorchDecoder : public IDecoder {
/// \brief Returns if output may contain alias of input in AliasDB
virtual bool may_produce_alias(size_t in_index, size_t out_index) const = 0;

/// Returns new nodes for inputs inlined in the op itself
// Used in Torch.FX decoder
virtual OutputVector inlined_input(size_t index) const = 0;

/// Returns if input is inlined
// Used in Torch.FX decoder
virtual bool is_input_inlined(size_t index) const = 0;

/// Return decoder for inlined input
virtual std::shared_ptr<TorchDecoder> get_inlined_input_decoder(size_t index) const = 0;

/// Returns named attribute as Any. For example kwargs input for FX graph
virtual ov::Any get_attribute(const std::string& name) const = 0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,44 +51,9 @@ class NodeContext : public frontend::NodeContext {

// Search for input in tensor map and return an output port for already converted op
// TODO: int due to base class uses it, but naturally it should be size_t for PT
Output<Node> get_input(int index) const override {
size_t index_ = static_cast<size_t>(index);
FRONT_END_GENERAL_CHECK(!m_decoder->input_is_none(index_),
"Input doesn't exist with index: ",
index,
" for operation ",
get_op_type());
auto input = m_decoder_inputs.at(index);
if (input == 0) {
// Case when input can be inlined (possible only for fx decoder)
if (m_decoder->is_input_inlined(index_)) {
auto inlined_input = m_decoder->inlined_input(index_);
FRONT_END_GENERAL_CHECK(inlined_input.size() == 1,
"Incorrect inlined input with index: ",
index,
" for operation ",
get_op_type());
return inlined_input[0];
}
}
FRONT_END_GENERAL_CHECK(m_tensor_map->count(input), "No tensor corresponding input: ", input, " exist.");
return m_tensor_map->at(input);
}
Output<Node> get_input(int index) const override;

Output<Node> get_input(const std::string& name) const override {
FRONT_END_GENERAL_CHECK(has_attribute(name), "Input with name ", name, " doesn't exist");
auto attr = get_attribute_as_any(name);
if (attr.is<Output<Node>>()) {
// Case when input is constant value
return attr.as<Output<Node>>();
} else if (attr.is<type::PyNone>()) {
// None means input is unknown type, most likely a Node
auto input = m_decoder->get_named_input(name);
FRONT_END_GENERAL_CHECK(m_tensor_map->count(input), "No tensor corresponding input: ", input, " exist.");
return m_tensor_map->at(input);
}
FRONT_END_GENERAL_CHECK(false, "Input has type which can't be converted to ov::Node.");
}
Output<Node> get_input(const std::string& name) const override;

Any get_values_from_const_input(int index) const override;

Expand Down
86 changes: 56 additions & 30 deletions src/frontends/pytorch/src/node_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,39 +145,65 @@ std::shared_ptr<ov::Model> NodeContext::convert_subgraph(size_t index) const {
return model;
}

Output<Node> NodeContext::get_input(int index) const {
size_t index_ = static_cast<size_t>(index);
auto input = m_decoder_inputs.at(index);
if (input == 0) {
// Case when input can be inlined (possible only for fx decoder)
if (m_decoder->is_input_inlined(index_)) {
if (m_decoder->input_is_none(index_)) {
// some operations like aten.index.Tensor can have None inputs
auto dummy_decoder = std::make_shared<InternalOpDecoder>("torch::None", 1);
auto fw_node = std::make_shared<PtFrameworkNode>(dummy_decoder, OutputVector{});
auto attrs = fw_node->get_attrs();
attrs["none_value"] = "";
attrs[PtFrameworkNode::failed_conversion_key] =
"None constant cannot be converted to OpenVINO opset and should be removed by consuming "
"operation.";
fw_node->set_attrs(attrs);
return fw_node->output(0);
} else {
auto inlined_decoder = m_decoder->get_inlined_input_decoder(index_);
auto inlined_ctx = NodeContext(inlined_decoder,
m_ext_tensor_map,
m_tensor_map,
m_external_parameters,
m_mutated_tensors,
m_translate_session);
auto inlined_input = m_translate_session->convert_node(inlined_ctx);
FRONT_END_GENERAL_CHECK(inlined_input.size() == 1,
"Incorrect inlined input with index: ",
index,
" for operation ",
get_op_type());
return inlined_input[0];
}
}
}
auto tensor_it = m_tensor_map->find(input);
FRONT_END_GENERAL_CHECK(tensor_it != m_tensor_map->end(), "No tensor corresponding input: ", input, " exist.");
return tensor_it->second;
}

Output<Node> NodeContext::get_input(const std::string& name) const {
FRONT_END_GENERAL_CHECK(has_attribute(name), "Input with name ", name, " doesn't exist");
auto attr = get_attribute_as_any(name);
if (attr.is<Output<Node>>()) {
// Case when input is constant value
return attr.as<Output<Node>>();
} else if (attr.is<type::PyNone>()) {
// None means input is unknown type, most likely a Node
auto input = m_decoder->get_named_input(name);
FRONT_END_GENERAL_CHECK(m_tensor_map->count(input), "No tensor corresponding input: ", input, " exist.");
return m_tensor_map->at(input);
}
FRONT_END_GENERAL_CHECK(false, "Input has type which can't be converted to ov::Node.");
}

OutputVector NodeContext::inputs() const {
OutputVector res;
for (size_t i = 0; i < m_decoder_inputs.size(); i++) {
auto input = m_decoder_inputs.at(i);
if (input == 0) {
// Case when input can be inlined (possible only for fx decoder)
if (m_decoder->is_input_inlined(i)) {
if (input_is_none(i)) {
// some operations like aten.index.Tensor can have None inputs
auto dummy_decoder = std::make_shared<InternalOpDecoder>("torch::None", 1);
auto fw_node = std::make_shared<PtFrameworkNode>(dummy_decoder, OutputVector{});
auto attrs = fw_node->get_attrs();
attrs["none_value"] = "";
attrs[PtFrameworkNode::failed_conversion_key] =
"None constant cannot be converted to OpenVINO opset and should be removed by consuming "
"operation.";
fw_node->set_attrs(attrs);
res.push_back(fw_node->output(0));
} else {
auto inlined_input = m_decoder->inlined_input(i);
FRONT_END_GENERAL_CHECK(inlined_input.size() == 1,
"Incorrect inlined input with index: ",
i,
" for operation ",
get_op_type());
res.push_back(inlined_input[0]);
}
continue;
}
}
auto tensor_it = m_tensor_map->find(input);
FRONT_END_GENERAL_CHECK(tensor_it != m_tensor_map->end(), "No tensor corresponding input: ", input, " exist.");
res.push_back(tensor_it->second);
res.push_back(get_input(i));
}
return res;
}
Expand Down
77 changes: 45 additions & 32 deletions src/frontends/pytorch/src/op/as_strided.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,20 @@ namespace pytorch {
namespace op {

using namespace ov::op;

namespace {
bool compare_strides(const std::tuple<size_t, size_t>& a, const std::tuple<size_t, size_t>& b) {
return std::get<0>(a) > std::get<0>(b);
}
OutputVector translate_as_strided(const NodeContext& context) {
// "aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)"
num_inputs_check(context, 3, 4);
auto decoder = context.get_decoder();
auto input = context.get_input(0);

OutputVector translate_as_strided_common(const NodeContext& context,
const Output<Node>& input,
const std::vector<size_t>& input_strides,
const std::deque<Output<Node>>& sizes,
const std::deque<Output<Node>>& strides) {
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
auto input_strides = decoder->get_input_strides(0);
PYTORCH_OP_CONVERSION_CHECK(input_strides.size() != 0,
"aten::as_strided: Couldn't retrieve input stride information from torchscript.");

std::vector<size_t> idxs(input_strides.size());
iota(idxs.begin(), idxs.end(), 0);
std::vector<std::tuple<size_t, size_t>> stride_idxs(idxs.size());
Expand All @@ -53,26 +52,6 @@ OutputVector translate_as_strided(const NodeContext& context) {
context.mark_node(v0::Constant::create(element::i32, Shape{transpose_idx.size()}, transpose_idx));
auto transposed_input = context.mark_node(std::make_shared<v1::Transpose>(input, transpose_idx_const));
auto flat_input = context.mark_node(std::make_shared<v1::Reshape>(transposed_input, const_neg_1, false));
std::deque<Output<Node>> sizes;
std::deque<Output<Node>> strides;
if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(1).get_node_shared_ptr())) {
auto input_vector = context.const_input<std::vector<int64_t>>(1);
std::for_each(input_vector.rbegin(), input_vector.rend(), [&](int64_t input_val) {
auto const_input = context.mark_node(v0::Constant::create(element::i32, Shape{}, {input_val}));
sizes.push_front(const_input);
});
} else {
sizes = get_list_as_outputs(context.get_input(1));
}
if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(2).get_node_shared_ptr())) {
auto input_vector = context.const_input<std::vector<int64_t>>(2);
std::for_each(input_vector.rbegin(), input_vector.rend(), [&](int64_t input_val) {
auto const_input = context.mark_node(v0::Constant::create(element::i32, Shape{}, {input_val}));
strides.push_front(const_input);
});
} else {
strides = get_list_as_outputs(context.get_input(2));
}
auto offset = const_0->output(0);
if (!context.input_is_none(3)) {
offset = get_input_as_i32(context, 3);
Expand All @@ -84,12 +63,12 @@ OutputVector translate_as_strided(const NodeContext& context) {
auto strides_length_const = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {strides.size()}));
auto ones_strides_len = context.mark_node(std::make_shared<v0::Tile>(const_1, strides_length_const));
auto indices = const_0;
std::for_each(strides.rbegin(), strides.rend(), [&](Output<Node>& stride) {
std::for_each(strides.rbegin(), strides.rend(), [&](const Output<Node>& stride) {
auto const_num_iter = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {strides_size - i}));
stride = context.mark_node(std::make_shared<v0::Convert>(stride, element::i32));
auto stride_conv = context.mark_node(std::make_shared<v0::Convert>(stride, element::i32));
auto size = sizes.at(strides_size - i);
auto range = context.mark_node(std::make_shared<v4::Range>(const_0, size, const_1, element::i32));
range = context.mark_node(std::make_shared<v1::Multiply>(range, stride));
range = context.mark_node(std::make_shared<v1::Multiply>(range, stride_conv));
auto iteration_shape = context.mark_node(
std::make_shared<v3::ScatterUpdate>(ones_strides_len, const_num_iter, const_neg_1, const_0));
range = context.mark_node(std::make_shared<v1::Reshape>(range, iteration_shape, false));
Expand All @@ -99,7 +78,41 @@ OutputVector translate_as_strided(const NodeContext& context) {
indices = context.mark_node(std::make_shared<v1::Add>(indices, offset));
auto gather = context.mark_node(std::make_shared<v8::Gather>(flat_input, indices, const_0));
return {gather};
}
} // namespace

OutputVector translate_as_strided(const NodeContext& context) {
// "aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)"
num_inputs_check(context, 3, 4);
auto decoder = context.get_decoder();
auto input = context.get_input(0);
auto input_strides = decoder->get_input_strides(0);
PYTORCH_OP_CONVERSION_CHECK(input_strides.size() != 0,
"aten::as_strided: Couldn't retrieve input stride information from torchscript.");

std::deque<Output<Node>> sizes;
std::deque<Output<Node>> strides;
if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(1).get_node_shared_ptr())) {
auto input_vector = context.const_input<std::vector<int64_t>>(1);
std::for_each(input_vector.rbegin(), input_vector.rend(), [&](int64_t input_val) {
auto const_input = context.mark_node(v0::Constant::create(element::i32, Shape{}, {input_val}));
sizes.push_front(const_input);
});
} else {
sizes = get_list_as_outputs(context.get_input(1));
}
if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(2).get_node_shared_ptr())) {
auto input_vector = context.const_input<std::vector<int64_t>>(2);
std::for_each(input_vector.rbegin(), input_vector.rend(), [&](int64_t input_val) {
auto const_input = context.mark_node(v0::Constant::create(element::i32, Shape{}, {input_val}));
strides.push_front(const_input);
});
} else {
strides = get_list_as_outputs(context.get_input(2));
}
return translate_as_strided_common(context, input, input_strides, sizes, strides);
};

} // namespace op
} // namespace pytorch
} // namespace frontend
Expand Down
15 changes: 4 additions & 11 deletions src/frontends/pytorch/src/op/cat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,11 @@ OutputVector translate_cat(const NodeContext& context) {
};

OutputVector translate_cat_fx(const NodeContext& context) {
// This translator is only needed to get axis as constant from external scope
num_inputs_check(context, 1, context.get_input_size());
std::deque<Output<Node>> list_elems;
for (size_t i = 0; i < context.get_input_size() - 1; i++) {
list_elems.push_back(context.get_input(static_cast<int>(i)));
}
num_inputs_check(context, 1, 2);
const auto&& list_elems = get_list_as_outputs(context.get_input(0));
int64_t axis = 0;
if (!context.get_input_type(context.get_input_size() - 1).is<type::List>()) {
// axis can be not present and that means that last input will have List type
axis = context.const_input<int64_t>(context.get_input_size() - 1);
} else {
list_elems.push_back(context.get_input(static_cast<int>(context.get_input_size() - 1)));
if (!context.input_is_none(1)) {
axis = context.const_input<int64_t>(1);
}
return translate_cat_common(context, list_elems, axis, true);
};
Expand Down
18 changes: 0 additions & 18 deletions src/frontends/pytorch/src/op/expand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,6 @@ OutputVector translate_expand_as(const NodeContext& context) {
return {context.mark_node(std::make_shared<v3::Broadcast>(x, shape, BroadcastType::BIDIRECTIONAL))};
};

OutputVector translate_expand_fx(const NodeContext& context) {
auto num_inputs = context.get_input_size();
num_inputs_check(context, 2, num_inputs);
auto x = context.get_input(0);
std::vector<int32_t> shape_vec;
if (context.get_input_type(1).is<type::List>()) {
auto concat = concat_list_from_inputs(context, 1, num_inputs);
return base_expand(context, x, concat);
} else {
auto x = context.get_input(0);
auto sizes = context.get_input(1);
// TODO: figure out what implicit means
PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(2) || context.const_input<bool>(2) == false,
"Unexpected value of implicit for expand operation");
return base_expand(context, x, sizes);
}
};

} // namespace op
} // namespace pytorch
} // namespace frontend
Expand Down
11 changes: 3 additions & 8 deletions src/frontends/pytorch/src/op/full.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,9 @@ OutputVector translate_full_fx(const NodeContext& context) {
// aten.full.default([16, 16], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'),
// pin_memory = False)
auto num_inputs = context.get_input_size();
num_inputs_check(context, 2, num_inputs);
ov::Output<ov::Node> sizes;
if (context.get_input_type(0).is<type::List>()) {
sizes = concat_list_from_inputs(context, 0, num_inputs - 1);
} else {
sizes = context.get_input(0);
}
auto value = context.get_input(static_cast<int>(num_inputs - 1));
num_inputs_check(context, 2, 2);
auto sizes = get_input_concat_if_list(context, 0);
auto value = context.get_input(1);

auto filled_tensor = base_translate_full(context, sizes, value);
if (context.has_attribute("dtype")) {
Expand Down
Loading

0 comments on commit 043bc89

Please sign in to comment.