From 1ddb12edca0e6a9fb02b37a3e1565111dfd22d98 Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Tue, 29 Oct 2024 12:30:41 +0400 Subject: [PATCH] Build w/lo FastTokenizers (#305) --- src/charsmap_normalization.cpp | 10 +- src/ov_extension.cpp | 6 +- src/regex_split.cpp | 2 +- src/sentence_piece.cpp | 5 +- src/sentence_piece.hpp | 134 +++++++-------- src/string_to_hash_bucket.cpp | 298 +++++++++++++++++---------------- src/tensorflow_translators.cpp | 41 ++--- src/tensorflow_translators.hpp | 2 +- 8 files changed, 253 insertions(+), 245 deletions(-) diff --git a/src/charsmap_normalization.cpp b/src/charsmap_normalization.cpp index 28c17fcfd..d5ff97395 100644 --- a/src/charsmap_normalization.cpp +++ b/src/charsmap_normalization.cpp @@ -10,10 +10,12 @@ using namespace ov; namespace { - std::shared_ptr make_identity_spec() { - auto spec = sentencepiece::SentencePieceTrainer::GetNormalizerSpec("identity"); - return std::make_shared(spec); - } + +std::shared_ptr make_identity_spec() { + auto spec = sentencepiece::SentencePieceTrainer::GetNormalizerSpec("identity"); + return std::make_shared(spec); +} + } // namespace diff --git a/src/ov_extension.cpp b/src/ov_extension.cpp index 70585e391..7369fe427 100644 --- a/src/ov_extension.cpp +++ b/src/ov_extension.cpp @@ -67,9 +67,9 @@ OPENVINO_CREATE_EXTENSIONS( std::make_shared>(), std::make_shared>(), std::make_shared>(), - std::make_shared>(), - std::make_shared>(), - std::make_shared>(), + std::make_shared>(), + std::make_shared>(), + std::make_shared>(), OPENVINO_TOKENIZERS_FAST_TOKENIZER_BASED_EXTENSIONS OPENVINO_TOKENIZERS_TENSORFLOW_CONVERSION_EXTENSIONS OPENVINO_TOKENIZERS_TENSORFLOW_CONVERSION_EXTENSIONS_FAST_TOKENIZER_BASED diff --git a/src/regex_split.cpp b/src/regex_split.cpp index 0bceaee2a..175aaac71 100644 --- a/src/regex_split.cpp +++ b/src/regex_split.cpp @@ -21,7 +21,7 @@ const std::map split_modes_map = { {"mergedwithnext", RegexSplit::SplitMode::MERGED_WITH_NEXT} }; -} +} // namespace void RegexSplit::compile_pattern_if_necessary(std::string split_pattern) const { m_split_mode = split_modes_map.at(m_behaviour); diff --git a/src/sentence_piece.cpp b/src/sentence_piece.cpp index 2804baaf4..0f05e734d 100644 --- a/src/sentence_piece.cpp +++ b/src/sentence_piece.cpp @@ -15,7 +15,6 @@ #include "utils.hpp" using sentencepiece::SentencePieceProcessor; -using namespace TemplateExtension; using namespace ov; using namespace ov::frontend; using namespace ov::opset13; @@ -55,6 +54,8 @@ int PieceToByte(absl::string_view piece) { } // namespace } // sentencepiece +namespace { + std::string form_extra_options(bool add_bos, bool add_eos, bool reverse) { std::string extra_options = ""; if (add_bos) { @@ -88,6 +89,8 @@ void init_sp_model_in_eval(const TensorVector& inputs, std::shared_ptrLoadFromSerializedProto(model_proto)); } +} // namespace + SentencepieceTokenizer::SentencepieceTokenizer(const OutputVector& args, int32_t nbest_size, float alpha, bool add_bos, bool add_eos, bool reverse) : m_sp(std::make_shared()), m_nbest_size(nbest_size), m_alpha(alpha), m_add_bos(add_bos), m_add_eos(add_eos), diff --git a/src/sentence_piece.hpp b/src/sentence_piece.hpp index e56f6a2ca..12888e2d9 100644 --- a/src/sentence_piece.hpp +++ b/src/sentence_piece.hpp @@ -10,96 +10,96 @@ #include "re2/re2.h" namespace sentencepiece { - class SentencePieceProcessor; - int PieceToByte(absl::string_view piece); -} -namespace TemplateExtension { - class SentencepieceTokenizer : public ov::op::Op { - public: - OPENVINO_OP("SentencepieceTokenizer"); +class SentencePieceProcessor; +int PieceToByte(absl::string_view piece); - SentencepieceTokenizer() = default; - SentencepieceTokenizer(const ov::OutputVector& args, int32_t nbest_size, float alpha, bool add_bos, bool add_eos, bool reverse); - SentencepieceTokenizer( - const ov::OutputVector& args, - const std::shared_ptr& sp, - const std::shared_ptr& special_tokens_re, - const std::shared_ptr>& special_tokens_map, - int32_t nbest_size, - float alpha, - bool add_bos, - bool add_eos, - bool reverse - ); +} // sentencepiece - bool visit_attributes(ov::AttributeVisitor& visitor) override; +class SentencepieceTokenizer : public ov::op::Op { +public: + OPENVINO_OP("SentencepieceTokenizer"); - void validate_and_infer_types() override; + SentencepieceTokenizer() = default; + SentencepieceTokenizer(const ov::OutputVector& args, int32_t nbest_size, float alpha, bool add_bos, bool add_eos, bool reverse); + SentencepieceTokenizer( + const ov::OutputVector& args, + const std::shared_ptr& sp, + const std::shared_ptr& special_tokens_re, + const std::shared_ptr>& special_tokens_map, + int32_t nbest_size, + float alpha, + bool add_bos, + bool add_eos, + bool reverse + ); - std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override; + bool visit_attributes(ov::AttributeVisitor& visitor) override; - bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const override; + void validate_and_infer_types() override; - bool has_evaluate() const override; + std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override; - private: - mutable std::shared_ptr m_sp; - mutable std::shared_ptr m_special_tokens_re; - mutable std::shared_ptr> m_special_tokens_map; - mutable std::mutex m_mutex; - int32_t m_nbest_size; - float m_alpha; - bool m_add_bos; - bool m_add_eos; - bool m_reverse; - }; + bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const override; + bool has_evaluate() const override; - class SentencepieceDetokenizer : public ov::op::Op { - public: - OPENVINO_OP("SentencepieceDetokenizer"); +private: + mutable std::shared_ptr m_sp; + mutable std::shared_ptr m_special_tokens_re; + mutable std::shared_ptr> m_special_tokens_map; + mutable std::mutex m_mutex; + int32_t m_nbest_size; + float m_alpha; + bool m_add_bos; + bool m_add_eos; + bool m_reverse; +}; - SentencepieceDetokenizer() = default; - SentencepieceDetokenizer(const ov::OutputVector& args); - SentencepieceDetokenizer(const ov::OutputVector& args, - const std::shared_ptr& sp); - bool visit_attributes(ov::AttributeVisitor& visitor) override; +class SentencepieceDetokenizer : public ov::op::Op { +public: + OPENVINO_OP("SentencepieceDetokenizer"); - void validate_and_infer_types() override; + SentencepieceDetokenizer() = default; + SentencepieceDetokenizer(const ov::OutputVector& args); + SentencepieceDetokenizer(const ov::OutputVector& args, + const std::shared_ptr& sp); - std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override; + bool visit_attributes(ov::AttributeVisitor& visitor) override; - bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const override; + void validate_and_infer_types() override; - bool has_evaluate() const override; + std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override; - private: - mutable std::shared_ptr m_sp; - }; + bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const override; + bool has_evaluate() const override; - class SentencepieceStreamDetokenizer : public ov::op::Op { - public: - OPENVINO_OP("SentencepieceStreamDetokenizer"); +private: + mutable std::shared_ptr m_sp; +}; - SentencepieceStreamDetokenizer() = default; - SentencepieceStreamDetokenizer(const ov::OutputVector& args); - SentencepieceStreamDetokenizer(const ov::OutputVector& args, - const std::shared_ptr& sp); - bool visit_attributes(ov::AttributeVisitor& visitor) override; +class SentencepieceStreamDetokenizer : public ov::op::Op { +public: + OPENVINO_OP("SentencepieceStreamDetokenizer"); - void validate_and_infer_types() override; + SentencepieceStreamDetokenizer() = default; + SentencepieceStreamDetokenizer(const ov::OutputVector& args); + SentencepieceStreamDetokenizer(const ov::OutputVector& args, + const std::shared_ptr& sp); - std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override; + bool visit_attributes(ov::AttributeVisitor& visitor) override; - bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const override; + void validate_and_infer_types() override; - bool has_evaluate() const override; + std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override; - private: - mutable std::shared_ptr m_sp; - }; -} // namespace TemplateExtension + bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const override; + + bool has_evaluate() const override; + +private: + mutable std::shared_ptr m_sp; +}; diff --git a/src/string_to_hash_bucket.cpp b/src/string_to_hash_bucket.cpp index 0716b6b1f..45cf578f3 100644 --- a/src/string_to_hash_bucket.cpp +++ b/src/string_to_hash_bucket.cpp @@ -8,181 +8,183 @@ using namespace ov; namespace { - static const uint64_t k0 = 0xc3a5c85c97cb3127ULL; - static const uint64_t k1 = 0xb492b66fbe98f273ULL; - static const uint64_t k2 = 0x9ae16a3b2f90404fULL; - - uint64_t hash_len16(uint64_t u, uint64_t v, uint64_t mul) { - uint64_t a = (u ^ v) * mul; - a ^= (a >> 47); - uint64_t b = (v ^ a) * mul; - b ^= (b >> 47); - b *= mul; - return b; - } - inline uint64_t basic_rotate64(uint64_t val, int shift) { - return shift == 0 ? val : ((val >> shift) | (val << (64 - shift))); - } +static const uint64_t k0 = 0xc3a5c85c97cb3127ULL; +static const uint64_t k1 = 0xb492b66fbe98f273ULL; +static const uint64_t k2 = 0x9ae16a3b2f90404fULL; + +uint64_t hash_len16(uint64_t u, uint64_t v, uint64_t mul) { + uint64_t a = (u ^ v) * mul; + a ^= (a >> 47); + uint64_t b = (v ^ a) * mul; + b ^= (b >> 47); + b *= mul; + return b; +} - inline uint64_t fetch(const char* p) { - uint64_t result; - std::memcpy(&result, p, sizeof(result)); - return result; - } +inline uint64_t basic_rotate64(uint64_t val, int shift) { + return shift == 0 ? val : ((val >> shift) | (val << (64 - shift))); +} + +inline uint64_t fetch(const char* p) { + uint64_t result; + std::memcpy(&result, p, sizeof(result)); + return result; +} #if defined(_MSC_VER) - uint64_t rotate(uint64_t val, int shift) { - return sizeof(unsigned long) == sizeof(val) ? _lrotr(val, shift) : basic_rotate64(val, shift); - } +uint64_t rotate(uint64_t val, int shift) { + return sizeof(unsigned long) == sizeof(val) ? _lrotr(val, shift) : basic_rotate64(val, shift); +} #else - uint64_t rotate(uint64_t val, int shift) { - return basic_rotate64(val, shift); - } +uint64_t rotate(uint64_t val, int shift) { + return basic_rotate64(val, shift); +} #endif - uint64_t hash_len17_to_32(const char* s, size_t len) { - uint64_t mul = k2 + len * 2; - uint64_t a = fetch(s) * k1; - uint64_t b = fetch(s + 8); - uint64_t c = fetch(s + len - 8) * mul; - uint64_t d = fetch(s + len - 16) * k2; - return hash_len16(rotate(a + b, 43) + rotate(c, 30) + d, a + rotate(b + k2, 18) + c, mul); - } +uint64_t hash_len17_to_32(const char* s, size_t len) { + uint64_t mul = k2 + len * 2; + uint64_t a = fetch(s) * k1; + uint64_t b = fetch(s + 8); + uint64_t c = fetch(s + len - 8) * mul; + uint64_t d = fetch(s + len - 16) * k2; + return hash_len16(rotate(a + b, 43) + rotate(c, 30) + d, a + rotate(b + k2, 18) + c, mul); +} - inline uint64_t shift_mix(uint64_t val) { - return val ^ (val >> 47); - } +inline uint64_t shift_mix(uint64_t val) { + return val ^ (val >> 47); +} - inline uint32_t fetch32(const char* p) { - uint32_t result; - memcpy(&result, p, sizeof(result)); - return result; - } +inline uint32_t fetch32(const char* p) { + uint32_t result; + memcpy(&result, p, sizeof(result)); + return result; +} - uint64_t hash_len0_to_16(const char* s, size_t len) { - if (len >= 8) { - uint64_t mul = k2 + len * 2; - uint64_t a = fetch(s) + k2; - uint64_t b = fetch(s + len - 8); - uint64_t c = rotate(b, 37) * mul + a; - uint64_t d = (rotate(a, 25) + b) * mul; - return hash_len16(c, d, mul); - } - if (len >= 4) { - uint64_t mul = k2 + len * 2; - uint64_t a = fetch32(s); - return hash_len16(len + (a << 3), fetch32(s + len - 4), mul); - } - if (len > 0) { - uint8_t a = s[0]; - uint8_t b = s[len >> 1]; - uint8_t c = s[len - 1]; - uint32_t y = static_cast(a) + (static_cast(b) << 8); - uint32_t z = len + (static_cast(c) << 2); - return shift_mix(y * k2 ^ z * k0) * k2; - } - return k2; +uint64_t hash_len0_to_16(const char* s, size_t len) { + if (len >= 8) { + uint64_t mul = k2 + len * 2; + uint64_t a = fetch(s) + k2; + uint64_t b = fetch(s + len - 8); + uint64_t c = rotate(b, 37) * mul + a; + uint64_t d = (rotate(a, 25) + b) * mul; + return hash_len16(c, d, mul); } - - uint64_t hash_len33_to_64(const char* s, size_t len) { + if (len >= 4) { uint64_t mul = k2 + len * 2; - uint64_t a = fetch(s) * k2; - uint64_t b = fetch(s + 8); - uint64_t c = fetch(s + len - 8) * mul; - uint64_t d = fetch(s + len - 16) * k2; - uint64_t y = rotate(a + b, 43) + rotate(c, 30) + d; - uint64_t z = hash_len16(y, a + rotate(b + k2, 18) + c, mul); - uint64_t e = fetch(s + 16) * mul; - uint64_t f = fetch(s + 24); - uint64_t g = (y + fetch(s + len - 32)) * mul; - uint64_t h = (z + fetch(s + len - 24)) * mul; - return hash_len16(rotate(e + f, 43) + rotate(g, 30) + h, e + rotate(f + a, 18) + g, mul); + uint64_t a = fetch32(s); + return hash_len16(len + (a << 3), fetch32(s + len - 4), mul); } - - std::pair weak_hash_len32_with_seeds(uint64_t w, - uint64_t x, - uint64_t y, - uint64_t z, - uint64_t a, - uint64_t b) { - a += w; - b = rotate(b + a + z, 21); - uint64_t c = a; - a += x; - a += y; - b += rotate(a, 44); - return std::make_pair(a + z, b + c); + if (len > 0) { + uint8_t a = s[0]; + uint8_t b = s[len >> 1]; + uint8_t c = s[len - 1]; + uint32_t y = static_cast(a) + (static_cast(b) << 8); + uint32_t z = len + (static_cast(c) << 2); + return shift_mix(y * k2 ^ z * k0) * k2; } + return k2; +} - std::pair weak_hash_len32_with_seeds(const char* s, uint64_t a, uint64_t b) { - return weak_hash_len32_with_seeds(fetch(s), fetch(s + 8), fetch(s + 16), fetch(s + 24), a, b); - } +uint64_t hash_len33_to_64(const char* s, size_t len) { + uint64_t mul = k2 + len * 2; + uint64_t a = fetch(s) * k2; + uint64_t b = fetch(s + 8); + uint64_t c = fetch(s + len - 8) * mul; + uint64_t d = fetch(s + len - 16) * k2; + uint64_t y = rotate(a + b, 43) + rotate(c, 30) + d; + uint64_t z = hash_len16(y, a + rotate(b + k2, 18) + c, mul); + uint64_t e = fetch(s + 16) * mul; + uint64_t f = fetch(s + 24); + uint64_t g = (y + fetch(s + len - 32)) * mul; + uint64_t h = (z + fetch(s + len - 24)) * mul; + return hash_len16(rotate(e + f, 43) + rotate(g, 30) + h, e + rotate(f + a, 18) + g, mul); +} + +std::pair weak_hash_len32_with_seeds(uint64_t w, + uint64_t x, + uint64_t y, + uint64_t z, + uint64_t a, + uint64_t b) { + a += w; + b = rotate(b + a + z, 21); + uint64_t c = a; + a += x; + a += y; + b += rotate(a, 44); + return std::make_pair(a + z, b + c); +} - uint64_t hash64(const char* s, size_t len) { - const uint64_t seed = 81; - if (len <= 32) { - if (len <= 16) { - return hash_len0_to_16(s, len); - } - else { - return hash_len17_to_32(s, len); - } +std::pair weak_hash_len32_with_seeds(const char* s, uint64_t a, uint64_t b) { + return weak_hash_len32_with_seeds(fetch(s), fetch(s + 8), fetch(s + 16), fetch(s + 24), a, b); +} + +uint64_t hash64(const char* s, size_t len) { + const uint64_t seed = 81; + if (len <= 32) { + if (len <= 16) { + return hash_len0_to_16(s, len); } - else if (len <= 64) { - return hash_len33_to_64(s, len); + else { + return hash_len17_to_32(s, len); } + } + else if (len <= 64) { + return hash_len33_to_64(s, len); + } - // For strings over 64 bytes we loop. Internal state consists of - // 56 bytes: v, w, x, y, and z. - uint64_t x = seed; - uint64_t y = seed * k1 + 113; - uint64_t z = shift_mix(y * k2 + 113) * k2; - std::pair v = std::make_pair(0, 0); - std::pair w = std::make_pair(0, 0); - x = x * k2 + fetch(s); - - // Set end so that after the loop we have 1 to 64 bytes left to process. - const char* end = s + ((len - 1) / 64) * 64; - const char* last64 = end + ((len - 1) & 63) - 63; - do { - x = rotate(x + y + v.first + fetch(s + 8), 37) * k1; - y = rotate(y + v.second + fetch(s + 48), 42) * k1; - x ^= w.second; - y += v.first + fetch(s + 40); - z = rotate(z + w.first, 33) * k1; - v = weak_hash_len32_with_seeds(s, v.second * k1, x + w.first); - w = weak_hash_len32_with_seeds(s + 32, z + w.second, y + fetch(s + 16)); - std::swap(z, x); - s += 64; - } while (s != end); - uint64_t mul = k1 + ((z & 0xff) << 1); - // Make s point to the last 64 bytes of input. - s = last64; - w.first += ((len - 1) & 63); - v.first += w.first; - w.first += v.first; - x = rotate(x + y + v.first + fetch(s + 8), 37) * mul; - y = rotate(y + v.second + fetch(s + 48), 42) * mul; - x ^= w.second * 9; - y += v.first * 9 + fetch(s + 40); - z = rotate(z + w.first, 33) * mul; - v = weak_hash_len32_with_seeds(s, v.second * mul, x + w.first); + // For strings over 64 bytes we loop. Internal state consists of + // 56 bytes: v, w, x, y, and z. + uint64_t x = seed; + uint64_t y = seed * k1 + 113; + uint64_t z = shift_mix(y * k2 + 113) * k2; + std::pair v = std::make_pair(0, 0); + std::pair w = std::make_pair(0, 0); + x = x * k2 + fetch(s); + + // Set end so that after the loop we have 1 to 64 bytes left to process. + const char* end = s + ((len - 1) / 64) * 64; + const char* last64 = end + ((len - 1) & 63) - 63; + do { + x = rotate(x + y + v.first + fetch(s + 8), 37) * k1; + y = rotate(y + v.second + fetch(s + 48), 42) * k1; + x ^= w.second; + y += v.first + fetch(s + 40); + z = rotate(z + w.first, 33) * k1; + v = weak_hash_len32_with_seeds(s, v.second * k1, x + w.first); w = weak_hash_len32_with_seeds(s + 32, z + w.second, y + fetch(s + 16)); std::swap(z, x); - return hash_len16(hash_len16(v.first, w.first, mul) + shift_mix(y) * k0 + z, - hash_len16(v.second, w.second, mul) + x, - mul); - } + s += 64; + } while (s != end); + uint64_t mul = k1 + ((z & 0xff) << 1); + // Make s point to the last 64 bytes of input. + s = last64; + w.first += ((len - 1) & 63); + v.first += w.first; + w.first += v.first; + x = rotate(x + y + v.first + fetch(s + 8), 37) * mul; + y = rotate(y + v.second + fetch(s + 48), 42) * mul; + x ^= w.second * 9; + y += v.first * 9 + fetch(s + 40); + z = rotate(z + w.first, 33) * mul; + v = weak_hash_len32_with_seeds(s, v.second * mul, x + w.first); + w = weak_hash_len32_with_seeds(s + 32, z + w.second, y + fetch(s + 16)); + std::swap(z, x); + return hash_len16(hash_len16(v.first, w.first, mul) + shift_mix(y) * k0 + z, + hash_len16(v.second, w.second, mul) + x, + mul); +} + +uint64_t hash64(const std::string& str) { + return hash64(str.data(), str.size()); +} - uint64_t hash64(const std::string& str) { - return hash64(str.data(), str.size()); - } } void StringToHashBucket::validate_and_infer_types() { diff --git a/src/tensorflow_translators.cpp b/src/tensorflow_translators.cpp index bd8cba981..e279d7521 100644 --- a/src/tensorflow_translators.cpp +++ b/src/tensorflow_translators.cpp @@ -22,39 +22,40 @@ #include "regex_split.hpp" #include "string_to_hash_bucket.hpp" #include "vocab_encoder.hpp" +#include "wordpiece_tokenizer.hpp" #ifdef ENABLE_FAST_TOKENIZERS #include "case_fold.hpp" #include "normalize_unicode.hpp" -#include "wordpiece_tokenizer.hpp" #endif // ENABLE_FAST_TOKENIZERS -using namespace TemplateExtension; using namespace ov; using namespace ov::op; using namespace ov::frontend; using namespace ov::opset13; namespace { - template - T extract_scalar_const_value(const std::shared_ptr& node, const std::string& const_name) { - auto const_node = as_type_ptr(node); - FRONT_END_GENERAL_CHECK(const_node, "Conversion expects " + const_name + " to be constant."); - std::vector const_value = const_node->cast_vector(); - FRONT_END_GENERAL_CHECK(const_value.size() == 1, "Conversion expects " + const_name + " to be a scalar."); - return const_value[0]; - } - Output compute_subgraph_scalar_rank(const Output& output, element::Type output_type, bool as_scalar) { - auto shape_of = std::make_shared(output, output_type); - auto rank_of = std::make_shared(shape_of, output_type); +template +T extract_scalar_const_value(const std::shared_ptr& node, const std::string& const_name) { + auto const_node = as_type_ptr(node); + FRONT_END_GENERAL_CHECK(const_node, "Conversion expects " + const_name + " to be constant."); + std::vector const_value = const_node->cast_vector(); + FRONT_END_GENERAL_CHECK(const_value.size() == 1, "Conversion expects " + const_name + " to be a scalar."); + return const_value[0]; +} - if (as_scalar) { - auto const_zero = std::make_shared(element::i32, Shape{}, 0); - return std::make_shared(rank_of, const_zero); - } - return rank_of; +Output compute_subgraph_scalar_rank(const Output& output, element::Type output_type, bool as_scalar) { + auto shape_of = std::make_shared(output, output_type); + auto rank_of = std::make_shared(shape_of, output_type); + + if (as_scalar) { + auto const_zero = std::make_shared(element::i32, Shape{}, 0); + return std::make_shared(rank_of, const_zero); } + return rank_of; +} + } // namespace OutputVector translate_sentencepiece_op(const NodeContext& node) { @@ -195,8 +196,6 @@ ov::OutputVector translate_regex_split_with_offsets(const ov::frontend::NodeCont return { post_translate_ragged_tensor_output({outputs[0], outputs[1], flatten_string_tensor}) }; } -#ifdef ENABLE_FAST_TOKENIZERS - ov::OutputVector translate_wordpiece_tokenize_with_offsets(const ov::frontend::NodeContext& node) { FRONT_END_GENERAL_CHECK(node.get_input_size() == 2, "WordpieceTokenizeWithOffsets expects 2 inputs"); ov::OutputVector inputs = pre_translate_ragged_string_tensor_input(node.get_input(0)); @@ -222,6 +221,8 @@ ov::OutputVector translate_wordpiece_tokenize_with_offsets(const ov::frontend::N return { post_translate_ragged_tensor_output(wp_tokenizer->outputs()) }; } +#ifdef ENABLE_FAST_TOKENIZERS + ov::OutputVector translate_string_lower(const ov::frontend::NodeContext& node) { auto node_name = node.get_name(); FRONT_END_GENERAL_CHECK(node.get_input_size() == 1, "StringLower expects only 1 input"); diff --git a/src/tensorflow_translators.hpp b/src/tensorflow_translators.hpp index 970f45431..4dbc26b23 100644 --- a/src/tensorflow_translators.hpp +++ b/src/tensorflow_translators.hpp @@ -16,10 +16,10 @@ ov::OutputVector translate_ragged_tensor_to_tensor(const ov::frontend::NodeConte ov::OutputVector translate_equal(const ov::frontend::NodeContext& node); ov::OutputVector translate_string_to_hash_bucket_fast(const ov::frontend::NodeContext& node); ov::OutputVector translate_squeeze_op(const ov::frontend::NodeContext& node); +ov::OutputVector translate_wordpiece_tokenize_with_offsets(const ov::frontend::NodeContext& node); #ifdef ENABLE_FAST_TOKENIZERS ov::OutputVector translate_string_lower(const ov::frontend::NodeContext& node); ov::OutputVector translate_case_fold_utf8(const ov::frontend::NodeContext& node); ov::OutputVector translate_normalize_utf8(const ov::frontend::NodeContext& node); -ov::OutputVector translate_wordpiece_tokenize_with_offsets(const ov::frontend::NodeContext& node); #endif // ENABLE_FAST_TOKENIZERS