Skip to content

Commit

Permalink
[TF FE] Support StringSplitV2 operation (openvinotoolkit#59)
Browse files Browse the repository at this point in the history
* [TF FE] Support StringSplitV2 operation

Signed-off-by: Kazantsev, Roman <[email protected]>

* Update src/tensorflow_translators.cpp

Co-authored-by: Artur Paniukov <[email protected]>

* Update src/tensorflow_translators.cpp

* Handle maxplit attribute correctly

Signed-off-by: Kazantsev, Roman <[email protected]>

---------

Signed-off-by: Kazantsev, Roman <[email protected]>
Co-authored-by: Artur Paniukov <[email protected]>
  • Loading branch information
rkazants and apaniukov authored Mar 6, 2024
1 parent 04143ac commit 843d8fc
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/ov_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("StringLower", translate_string_lower), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("StaticRegexReplace", translate_static_regex_replace), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("LookupTableFind", translate_lookup_table_find_op), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("LookupTableFindV2", translate_lookup_table_find_op)
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("LookupTableFindV2", translate_lookup_table_find_op), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("StringSplitV2", translate_string_split)
#else
#define OPENVINO_TOKENIZERS_TENSORFLOW_CONVERSION_EXTENSIONS
#endif
Expand Down
61 changes: 61 additions & 0 deletions src/tensorflow_translators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "sentence_piece.hpp"
#include "case_fold.hpp"
#include "normalize_unicode.hpp"
#include "ragged_to_sparse.hpp"
#include "regex_normalization.hpp"
#include "regex_split.hpp"
#include "vocab_encoder.hpp"
Expand Down Expand Up @@ -256,3 +257,63 @@ OutputVector translate_lookup_table_find_op(const ov::frontend::tensorflow::Node

return { lookup_values };
}

ov::OutputVector translate_string_split(const ov::frontend::NodeContext& node) {
auto node_name = node.get_name();
FRONT_END_GENERAL_CHECK(node.get_input_size() == 2, "StringSplitV2 expects two inputs (1D input and separator)");
auto input = node.get_input(0);
ov::OutputVector unpacked_input = pre_translate_string_tensor_input(input);
auto sep_const = ov::as_type_ptr<Constant>(node.get_input(1).get_node_shared_ptr());
TENSORFLOW_OP_VALIDATION(node, sep_const, "[TensorFlow Frontend] internal error: only constant separator is supported for StringSplitV2");
auto sep_value = sep_const->cast_vector<std::string>();
TENSORFLOW_OP_VALIDATION(node, sep_value.size() == 1, "[TensorFlow Frontend] inconsistent model: separator must be a scalar");
auto sep = std::make_shared<Constant>(element::u8, Shape{ sep_value[0].length() }, (const void*)sep_value[0].data())->output(0);
if (sep_value[0] == "") {
// default case that means string elements will be removed from leading and trailing white-space
std::string pattern_value = "^\\s+|\\s+$";
auto pattern_constant = std::make_shared<Constant>(element::u8, Shape{ pattern_value.length() }, (const void*)pattern_value.data());
std::string rewrite_value = "";
auto rewrite_constant = std::make_shared<Constant>(element::u8, Shape{ rewrite_value.length() }, (const void*)rewrite_value.data());
ov::OutputVector inputs = unpacked_input;
inputs.push_back(pattern_constant);
inputs.push_back(rewrite_constant);
unpacked_input = std::make_shared<RegexNormalization>(inputs, true)->outputs();
std::string new_sep_value = "[\\s\\p{Zs}]+";
sep = std::make_shared<Constant>(element::u8, Shape{ new_sep_value.length() }, (const void*)new_sep_value.data());
}
auto maxsplit = node.get_attribute<int64_t>("maxsplit", -1);
TENSORFLOW_OP_VALIDATION(node, maxsplit == -1, "[TensorFlow Frontend] internal error: only maxsplit equal to -1 is supported for StringSplitV2");

// compute batch_dim to generate ragged_begins and ragged_ends for RegexSplit
auto input_shape = std::make_shared<ShapeOf>(input, element::i32);
auto squeeze_axis = std::make_shared<Constant>(element::i32, Shape{ 1 }, std::vector<int32_t>{0});
auto batch_dim = std::make_shared<Squeeze>(input_shape, squeeze_axis);
auto zero_const = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
auto one_const = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{1});
auto ragged_begins = std::make_shared<Range>(zero_const, batch_dim, one_const, element::i32);
auto ragged_ends = std::make_shared<Add>(ragged_begins, one_const);

auto regex_split_outputs = std::make_shared<RegexSplit>(ov::OutputVector{ ragged_begins, ragged_ends, unpacked_input[0],
unpacked_input[1], unpacked_input[2], sep }, nullptr, "remove", false, maxsplit)->outputs();


// compute sparse tensor indices
auto indices = std::make_shared<RaggedToSparse>(ov::OutputVector{ regex_split_outputs[0], regex_split_outputs[1] })->output(0);
indices = std::make_shared<Convert>(indices, element::i64);
indices.set_names({ node_name + ":0" });

// compute values of Sparse Tensor of ov::element::string type
auto values = post_translate_string_tensor_output(ov::OutputVector{ regex_split_outputs[2], regex_split_outputs[3], regex_split_outputs[4] });
values.set_names({ node_name + ":1" });

// compute a shape of output tensor in a dense form
// compute maximum number of string elements per batch in output tensor after split
auto max_num_per_batch = std::make_shared<Subtract>(regex_split_outputs[1], regex_split_outputs[0])->output(0);
auto reduction_axes = std::make_shared<Constant>(element::i32, Shape{ 1 }, std::vector<int32_t>{0});
max_num_per_batch = std::make_shared<ReduceMax>(max_num_per_batch, reduction_axes, true);
auto shape = std::make_shared<Concat>(ov::OutputVector{ input_shape, max_num_per_batch }, 0)->output(0);
shape = std::make_shared<Convert>(shape, element::i64);
shape.set_names({ node_name + ":2" });

return ov::OutputVector{ indices, values, shape };
}
1 change: 1 addition & 0 deletions src/tensorflow_translators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#ifdef OpenVINO_Frontend_TensorFlow_FOUND
#include <openvino/frontend/tensorflow/node_context.hpp>
ov::OutputVector translate_lookup_table_find_op(const ov::frontend::tensorflow::NodeContext& node);
ov::OutputVector translate_string_split(const ov::frontend::NodeContext& node);
#endif

ov::OutputVector translate_sentencepiece_op(const ov::frontend::NodeContext& node);
Expand Down

0 comments on commit 843d8fc

Please sign in to comment.