diff --git a/src/ov_extension.cpp b/src/ov_extension.cpp index c192c47cc..947de141f 100644 --- a/src/ov_extension.cpp +++ b/src/ov_extension.cpp @@ -18,7 +18,9 @@ std::make_shared("SentencepieceOp", translate_sentencepiece_op), \ std::make_shared("RaggedTensorToSparse", translate_sentencepiece_tokenizer), \ std::make_shared("StringLower", translate_string_lower), \ - std::make_shared("StaticRegexReplace", translate_static_regex_replace), + std::make_shared("StaticRegexReplace", translate_static_regex_replace), \ + std::make_shared("LookupTableFind", translate_lookup_table_find_op), \ + std::make_shared("LookupTableFindV2", translate_lookup_table_find_op) #else #define OPENVINO_TOKENIZERS_TENSORFLOW_CONVERSION_EXTENSIONS #endif diff --git a/src/tensorflow_translators.cpp b/src/tensorflow_translators.cpp index 583c7af76..cafe5515e 100644 --- a/src/tensorflow_translators.cpp +++ b/src/tensorflow_translators.cpp @@ -2,6 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "openvino/frontend/tensorflow/hash_table.hpp" + #include "openvino/op/util/framework_node.hpp" #include "openvino/opsets/opset13.hpp" @@ -161,3 +163,71 @@ ov::OutputVector translate_string_lower(const ov::frontend::NodeContext& node) { set_node_name(node_name, string_lower_result.get_node_shared_ptr()); return { string_lower_result }; } + +OutputVector translate_lookup_table_find_op(const ov::frontend::tensorflow::NodeContext& node) { + FRONT_END_GENERAL_CHECK(node.get_input_size() == 3, "LookupTableFind or LookupTableFindV2 expects 3 inputs"); + auto table_handle = as_type_ptr(node.get_input_by_reference(0).get_node_shared_ptr()); + TENSORFLOW_OP_VALIDATION( + node, + table_handle, + "[TensorFlow Frontend] internal error: LookupTableFind operation expects table_handle by the first input"); + TENSORFLOW_OP_VALIDATION( + node, + table_handle->is_initialized(), + "[TensorFlow Frontend] internal error: LookupTableFind operation expects initialized table_handle"); + auto keys = node.get_input(1); + auto default_value = node.get_input(2); + + auto key_type = table_handle->get_key_type(); + TENSORFLOW_OP_VALIDATION( + node, + key_type.is_integral_number(), + "[TensorFlow Frontend] internal error: LookupTableFind is only supported for integer keys"); + + auto all_keys = table_handle->get_keys(); + auto all_values = table_handle->get_values(); + + // reshape both all values and keys to 1D tensor to work it further + auto target_shape = std::make_shared(element::i32, Shape{ 1 }, std::vector{-1}); + all_keys = std::make_shared(all_keys, target_shape, false); + all_values = std::make_shared(all_values, target_shape, false); + + // update all values with default value and all keys + auto default_value_shape = std::make_shared(element::i32, Shape{ 1 }, std::vector{1}); + default_value = std::make_shared(default_value, default_value_shape, false); + all_values = std::make_shared(OutputVector{ all_values, default_value }, 0); + auto num_keys = std::make_shared(all_keys, element::i64)->output(0); + auto scalar_shape = std::make_shared(element::i32, Shape{ 0 }, std::vector{}); + num_keys = std::make_shared(num_keys, scalar_shape, false); + num_keys = std::make_shared(num_keys, key_type); + + // compute mask which keys are not valid and for which default value must be used + auto unsqueeze_axis = std::make_shared(element::i32, Shape{ 1 }, std::vector{-1}); + auto unsqueeze_keys = std::make_shared(keys, unsqueeze_axis); + auto equal_mask = std::make_shared(all_keys, unsqueeze_keys)->output(0); + auto reduce_equal_mask = std::make_shared(equal_mask, unsqueeze_axis, false); + + // map keys to new keys from range [0, n], n index will be for out-of-range keys + // 1. generate mask-01 of shape [keys_shape, len(all_keys)], + // where 0 - not found key, 1 - found key + auto const_zero = std::make_shared(key_type, Shape{}, 0); + auto const_one = std::make_shared(key_type, Shape{}, 1); + auto mask01 = std::make_shared(reduce_equal_mask, new_keys, num_keys); + + // at this point all keys are sorted and are from the range [0, n] + // and keys are also mapped to this range + auto gather_axis = std::make_shared(element::i32, Shape{ 1 }, std::vector{0}); + auto lookup_values = std::make_shared(all_values, new_keys, gather_axis); + set_node_name(node.get_name(), lookup_values); + + return { lookup_values }; +} diff --git a/src/tensorflow_translators.hpp b/src/tensorflow_translators.hpp index 941918c4c..c4a2d6af0 100644 --- a/src/tensorflow_translators.hpp +++ b/src/tensorflow_translators.hpp @@ -5,6 +5,10 @@ #pragma once #include +#ifdef OpenVINO_Frontend_TensorFlow_FOUND +#include +ov::OutputVector translate_lookup_table_find_op(const ov::frontend::tensorflow::NodeContext& node); +#endif ov::OutputVector translate_sentencepiece_op(const ov::frontend::NodeContext& node); ov::frontend::NamedOutputVector translate_sentencepiece_tokenizer(const ov::frontend::NodeContext& node);