Skip to content

Commit

Permalink
[TF FE] Add conversion extension for LookupTableFind operation (openv…
Browse files Browse the repository at this point in the history
…inotoolkit#47)

* [TF FE] Add conversion extension for LookupTableFind operation

Signed-off-by: Kazantsev, Roman <[email protected]>

* Update src/ov_extension.cpp

* Update src/ov_extension.cpp

* Update src/tensorflow_translators.hpp

* Update src/tensorflow_translators.hpp

---------

Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants authored Mar 1, 2024
1 parent c89b9fa commit 35d7dcb
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/ov_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("SentencepieceOp", translate_sentencepiece_op), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("RaggedTensorToSparse", translate_sentencepiece_tokenizer), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("StringLower", translate_string_lower), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("StaticRegexReplace", translate_static_regex_replace),
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("StaticRegexReplace", translate_static_regex_replace), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("LookupTableFind", translate_lookup_table_find_op), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("LookupTableFindV2", translate_lookup_table_find_op)
#else
#define OPENVINO_TOKENIZERS_TENSORFLOW_CONVERSION_EXTENSIONS
#endif
Expand Down
70 changes: 70 additions & 0 deletions src/tensorflow_translators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/tensorflow/hash_table.hpp"

#include "openvino/op/util/framework_node.hpp"
#include "openvino/opsets/opset13.hpp"

Expand Down Expand Up @@ -161,3 +163,71 @@ ov::OutputVector translate_string_lower(const ov::frontend::NodeContext& node) {
set_node_name(node_name, string_lower_result.get_node_shared_ptr());
return { string_lower_result };
}

OutputVector translate_lookup_table_find_op(const ov::frontend::tensorflow::NodeContext& node) {
FRONT_END_GENERAL_CHECK(node.get_input_size() == 3, "LookupTableFind or LookupTableFindV2 expects 3 inputs");
auto table_handle = as_type_ptr<ov::frontend::tensorflow::HashTable>(node.get_input_by_reference(0).get_node_shared_ptr());
TENSORFLOW_OP_VALIDATION(
node,
table_handle,
"[TensorFlow Frontend] internal error: LookupTableFind operation expects table_handle by the first input");
TENSORFLOW_OP_VALIDATION(
node,
table_handle->is_initialized(),
"[TensorFlow Frontend] internal error: LookupTableFind operation expects initialized table_handle");
auto keys = node.get_input(1);
auto default_value = node.get_input(2);

auto key_type = table_handle->get_key_type();
TENSORFLOW_OP_VALIDATION(
node,
key_type.is_integral_number(),
"[TensorFlow Frontend] internal error: LookupTableFind is only supported for integer keys");

auto all_keys = table_handle->get_keys();
auto all_values = table_handle->get_values();

// reshape both all values and keys to 1D tensor to work it further
auto target_shape = std::make_shared<Constant>(element::i32, Shape{ 1 }, std::vector<int32_t>{-1});
all_keys = std::make_shared<Reshape>(all_keys, target_shape, false);
all_values = std::make_shared<Reshape>(all_values, target_shape, false);

// update all values with default value and all keys
auto default_value_shape = std::make_shared<Constant>(element::i32, Shape{ 1 }, std::vector<int32_t>{1});
default_value = std::make_shared<Reshape>(default_value, default_value_shape, false);
all_values = std::make_shared<Concat>(OutputVector{ all_values, default_value }, 0);
auto num_keys = std::make_shared<ShapeOf>(all_keys, element::i64)->output(0);
auto scalar_shape = std::make_shared<Constant>(element::i32, Shape{ 0 }, std::vector<int32_t>{});
num_keys = std::make_shared<Reshape>(num_keys, scalar_shape, false);
num_keys = std::make_shared<Convert>(num_keys, key_type);

// compute mask which keys are not valid and for which default value must be used
auto unsqueeze_axis = std::make_shared<Constant>(element::i32, Shape{ 1 }, std::vector<int32_t>{-1});
auto unsqueeze_keys = std::make_shared<Unsqueeze>(keys, unsqueeze_axis);
auto equal_mask = std::make_shared<Equal>(all_keys, unsqueeze_keys)->output(0);
auto reduce_equal_mask = std::make_shared<ReduceLogicalOr>(equal_mask, unsqueeze_axis, false);

// map keys to new keys from range [0, n], n index will be for out-of-range keys
// 1. generate mask-01 of shape [keys_shape, len(all_keys)],
// where 0 - not found key, 1 - found key
auto const_zero = std::make_shared<Constant>(key_type, Shape{}, 0);
auto const_one = std::make_shared<Constant>(key_type, Shape{}, 1);
auto mask01 = std::make_shared<Select>(equal_mask, const_one, const_zero);
// 2. generate a range [0, n-1] that will be multiplied to mask for computation of new keys
auto new_all_keys = std::make_shared<Range>(const_zero, num_keys, const_one, key_type);
// 3. compute new keys
auto reduce_axis = std::make_shared<Constant>(element::i32, Shape{ 1 }, std::vector<int32_t>{-1});
auto new_keys = std::make_shared<Multiply>(mask01, new_all_keys)->output(0);
new_keys = std::make_shared<ReduceMax>(new_keys, reduce_axis, false);

// replace invalid keys with key_for_default_value
new_keys = std::make_shared<Select>(reduce_equal_mask, new_keys, num_keys);

// at this point all keys are sorted and are from the range [0, n]
// and keys are also mapped to this range
auto gather_axis = std::make_shared<Constant>(element::i32, Shape{ 1 }, std::vector<int32_t>{0});
auto lookup_values = std::make_shared<Gather>(all_values, new_keys, gather_axis);
set_node_name(node.get_name(), lookup_values);

return { lookup_values };
}
4 changes: 4 additions & 0 deletions src/tensorflow_translators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
#pragma once

#include <openvino/frontend/node_context.hpp>
#ifdef OpenVINO_Frontend_TensorFlow_FOUND
#include <openvino/frontend/tensorflow/node_context.hpp>
ov::OutputVector translate_lookup_table_find_op(const ov::frontend::tensorflow::NodeContext& node);
#endif

ov::OutputVector translate_sentencepiece_op(const ov::frontend::NodeContext& node);
ov::frontend::NamedOutputVector translate_sentencepiece_tokenizer(const ov::frontend::NodeContext& node);
Expand Down

0 comments on commit 35d7dcb

Please sign in to comment.