diff --git a/python/openvino_tokenizers/cli.py b/python/openvino_tokenizers/cli.py
index bc1293f1c..d14b3e084 100644
--- a/python/openvino_tokenizers/cli.py
+++ b/python/openvino_tokenizers/cli.py
@@ -45,7 +45,7 @@ def get_parser() -> ArgumentParser:
         type=str,
         help=(
             "The model id of a tokenizer hosted inside a model repo on huggingface.co "
-            "or a path to a saved Huggingface tokenizer directory"
+            "or a path to a saved Huggingface tokenizer directory."
         ),
     )
     parser.add_argument(
@@ -152,7 +152,17 @@ def get_parser() -> ArgumentParser:
             "tokenizer and then converts it to OpenVINO. Might result in slightly different tokenizer. "
             "See models with _slow suffix https://github.com/openvinotoolkit/openvino_contrib/tree/master/modules/"
             "custom_operations/user_ie_extensions/tokenizer/python#output-match-by-model to check the potential "
-            "difference between original and OpenVINO tokenizers"
+            "difference between original and OpenVINO tokenizers."
+        ),
+    )
+    parser.add_argument(
+        "--handle-special-tokens-with-re",
+        "--handle_special_tokens_with_re",
+        required=False,
+        action="store_true",
+        help=(
+            "Use a separate regex to handle special tokens for sentencepiece-based tokenizers. Use this option if the "
+            "converted tokenizer doesn't use special tokens during tokenization."
         ),
     )
     parser.add_argument(
@@ -162,7 +172,7 @@ def get_parser() -> ArgumentParser:
         action="store_true",
         help=(
             "Pass `trust_remote_code=True` to `AutoTokenizer.from_pretrained`. It will "
-            "execute code present on the Hub on your local machine"
+            "execute code present on the Huggingface Hub on your machine!"
         ),
     )
     parser.add_argument(
@@ -172,7 +182,7 @@ def get_parser() -> ArgumentParser:
         action=StringToTypeAction,
         default=Type.i64,
         choices=["i32", "i64"],
-        help="Type of the output tensors for tokenizer",
+        help="Type of the output tensors for tokenizer.",
     )
     parser.add_argument(
         "--detokenizer-input-type",
@@ -181,7 +191,7 @@ def get_parser() -> ArgumentParser:
         action=StringToTypeAction,
         default=Type.i64,
         choices=["i32", "i64"],
-        help="Type of the input tensor for detokenizer",
+        help="Type of the input tensor for detokenizer.",
     )
     parser.add_argument(
         "--streaming-detokenizer",
@@ -190,7 +200,7 @@ def get_parser() -> ArgumentParser:
         action="store_true",
         help=(
             "[Experimental] Modify SentencePiece based detokenizer to keep spaces leading space. "
-            "Can be used to stream a model output without TextStreamer buffer"
+            "Can be used to stream a model output without TextStreamer buffer."
         ),
     )
     return parser
@@ -232,6 +242,7 @@ def convert_hf_tokenizer() -> None:
         detokenizer_input_type=args.detokenizer_input_type,
         streaming_detokenizer=args.streaming_detokenizer,
         use_max_padding=args.max_padding is not None,
+        handle_special_tokens_with_re=args.handle_special_tokens_with_re,
     )
     if not isinstance(converted, tuple):
        converted = (converted,)
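Note: the new `--handle-special-tokens-with-re` flag is forwarded to the Python API as the `handle_special_tokens_with_re` keyword argument (see `convert_tokenizer.py` below). A minimal usage sketch; the model id is a placeholder for any sentencepiece-based tokenizer and is not part of this patch:

```python
from transformers import AutoTokenizer

from openvino_tokenizers import convert_tokenizer

# Placeholder model id: any sentencepiece-based tokenizer should work here.
hf_tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Equivalent to passing --handle-special-tokens-with-re on the CLI: special
# tokens are matched by a separate regex instead of being passed through the
# SentencePiece model itself.
ov_tokenizer = convert_tokenizer(hf_tokenizer, handle_special_tokens_with_re=True)
```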
diff --git a/python/openvino_tokenizers/convert_tokenizer.py b/python/openvino_tokenizers/convert_tokenizer.py
index 48b9398fc..2e4ce3640 100644
--- a/python/openvino_tokenizers/convert_tokenizer.py
+++ b/python/openvino_tokenizers/convert_tokenizer.py
@@ -25,6 +25,7 @@ def convert_tokenizer(
     detokenizer_input_type: Type = Type.i64,
     streaming_detokenizer: bool = False,
     use_max_padding: bool = False,
+    handle_special_tokens_with_re: bool = False,
 ) -> Union[Model, Tuple[Model, Model]]:
     ov_tokenizers = None

@@ -50,6 +51,7 @@ def convert_tokenizer(
             add_special_tokens=add_special_tokens,
             skip_special_tokens=skip_special_tokens,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            handle_special_tokens_with_re=handle_special_tokens_with_re,
         )
     elif is_tiktoken_model(tokenizer_object):
         logger.info("Convert tiktoken-based tokenizer")
diff --git a/python/openvino_tokenizers/hf_parser.py b/python/openvino_tokenizers/hf_parser.py
index 80ca60e19..2b8e9d231 100644
--- a/python/openvino_tokenizers/hf_parser.py
+++ b/python/openvino_tokenizers/hf_parser.py
@@ -30,6 +30,7 @@
     TOKENIZER_NAME,
 )
 from .tokenizer_pipeline import (
+    BasePipelineStep,
     BPETokenizationStep,
     ByteFallbackStep,
     BytesToCharsStep,
@@ -518,6 +519,7 @@ def convert_sentencepiece_model_tokenizer(
     skip_special_tokens: bool = False,
     clean_up_tokenization_spaces: Optional[bool] = False,
     add_prefix_space: Optional[bool] = None,
+    handle_special_tokens_with_re: bool = False,
 ) -> Union[Model, Tuple[Model, Model]]:
     if not is_sentencepiece_model(hf_tokenizer):
         raise OVTypeError("Cannot convert tokenizer of this type without `.model` file.")
@@ -604,7 +606,7 @@ def convert_sentencepiece_model_tokenizer(
         add_tokens=add_tokens,
         hf_tokenizer=hf_tokenizer,
         skip_special_tokens=False,
-        add_prefix_space=add_prefix_space,
+        add_prefix_space=add_prefix_space and not handle_special_tokens_with_re,
     )
     sp_model = np.frombuffer(sp_model_string, dtype=np.uint8)
     sp_model_node = as_node(sp_model)
@@ -623,16 +625,25 @@ def convert_sentencepiece_model_tokenizer(
     input_node.set_friendly_name("string_input")
     next_node = input_node.outputs()

-    if prepend_scheme == "first":
+    if prepend_scheme == "first" or (add_prefix_space and handle_special_tokens_with_re):
         next_node = _get_factory().create("StringTensorUnpack", next_node).outputs()
         next_node = RegexNormalizationStep.add_prefix_whitespace_to_not_whitespace_regex().get_ov_subgraph(next_node)
         next_node = _get_factory().create("StringTensorPack", next_node).outputs()

     do_left_padding = hf_tokenizer.padding_side == "left"

+    if handle_special_tokens_with_re:
+        tokens, ids = zip(*sorted(((token, id) for id, token in add_tokens.items()), reverse=True))
+        added_inputs = [
+            *BasePipelineStep.create_string_constant_node(tokens).outputs(),
+            make_constant_node(np.array(ids, dtype=np.int32), Type.i32).output(0),
+        ]
+    else:
+        added_inputs = []
+
     tokenizer_node = _get_factory().create(
         "SentencepieceTokenizer",
-        [sp_model_node, *next_node],
+        [sp_model_node, *next_node] + added_inputs,
         {
             "add_bos": add_bos_token,
             "add_eos": add_eos_token,
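Note on the `hf_parser.py` change above: with the option enabled, the special tokens and their ids are appended as extra inputs to the `SentencepieceTokenizer` node. The `sorted(..., reverse=True)` call orders tokens reverse-lexicographically, so among tokens sharing a prefix the longer one comes first and the regex alternation built from them prefers the longest match. An illustrative sketch with a made-up `add_tokens` mapping (not from the patch):

```python
# Hypothetical id -> token mapping, shaped like the parser's `add_tokens`.
add_tokens = {3: "<sep>", 1: "<s>", 2: "<s>x"}

# Reverse-lexicographic sort: "<s>x" lands before its prefix "<s>".
tokens, ids = zip(*sorted(((token, idx) for idx, token in add_tokens.items()), reverse=True))

assert tokens == ("<s>x", "<s>", "<sep>")
assert ids == (2, 1, 3)
```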
diff --git a/src/sentence_piece.cpp b/src/sentence_piece.cpp
index e31736b07..9912580d1 100644
--- a/src/sentence_piece.cpp
+++ b/src/sentence_piece.cpp
@@ -5,7 +5,6 @@
 #include 

 #include "sentencepiece_processor.h"
-#include "absl/container/flat_hash_map.h"

 #include "openvino/op/util/framework_node.hpp"
 #include "openvino/opsets/opset13.hpp"
@@ -72,7 +71,7 @@ std::string form_extra_options(bool add_bos, bool add_eos, bool reverse) {

 void init_sp_model(const OutputVector& args, std::shared_ptr<SentencePieceProcessor>& sp) {
     auto sp_model_const = as_type_ptr<Constant>(args[0].get_node_shared_ptr());
-    FRONT_END_GENERAL_CHECK(sp_model_const, "SentencepieceTokenizer expects SentencePiece model to be constant.");
+    OPENVINO_ASSERT(sp_model_const, "SentencepieceTokenizer expects SentencePiece model to be constant.");
     auto spm_model = static_cast<const char*>(sp_model_const->get_data_ptr());
     auto spm_model_size = sp_model_const->get_byte_size();
@@ -93,40 +92,60 @@ SentencepieceTokenizer::SentencepieceTokenizer(const OutputVector& args, int32_t
     m_reverse(reverse), Op(args) {
     init_sp_model(args, m_sp);
-    CHECK_OK(m_sp->SetEncodeExtraOptions(form_extra_options(m_add_bos, m_add_eos, m_reverse)));
+    auto do_reverse = (m_reverse && get_input_size() < 5); // do not reverse if special_tokens_re is used
+    CHECK_OK(m_sp->SetEncodeExtraOptions(form_extra_options(m_add_bos, m_add_eos, do_reverse)));
     constructor_validate_and_infer_types();
 }

-SentencepieceTokenizer::SentencepieceTokenizer(const OutputVector& args, const std::shared_ptr<SentencePieceProcessor>& sp,
-    int32_t nbest_size, float alpha, bool add_bos, bool add_eos, bool reverse) :
+SentencepieceTokenizer::SentencepieceTokenizer(
+    const OutputVector& args,
+    const std::shared_ptr<SentencePieceProcessor>& sp,
+    const std::shared_ptr<re2::RE2>& special_tokens_re,
+    const std::shared_ptr<absl::flat_hash_map<std::string, int32_t>>& special_tokens_map,
+    int32_t nbest_size,
+    float alpha,
+    bool add_bos,
+    bool add_eos,
+    bool reverse
+) :
     m_sp((sp == nullptr) ? std::make_shared<SentencePieceProcessor>(): sp),
+    m_special_tokens_re(special_tokens_re),
+    m_special_tokens_map(special_tokens_map),
     m_nbest_size(nbest_size),
     m_alpha(alpha),
     m_add_bos(add_bos),
     m_add_eos(add_eos),
     m_reverse(reverse),
     Op(args) {
     // constructor above without sp argument never called when the node is created with python factory, so need to init and cache m_sp here
     if (!m_sp->status().ok()) {
         init_sp_model(args, m_sp);
-        CHECK_OK(m_sp->SetEncodeExtraOptions(form_extra_options(m_add_bos, m_add_eos, m_reverse)));
+        auto do_reverse = (m_reverse && get_input_size() < 5); // do not reverse if special_tokens_re is used
+        CHECK_OK(m_sp->SetEncodeExtraOptions(form_extra_options(m_add_bos, m_add_eos, do_reverse)));
     };
     constructor_validate_and_infer_types();
 }

 void SentencepieceTokenizer::validate_and_infer_types() {
-    FRONT_END_GENERAL_CHECK(get_input_element_type(0) == element::u8, "SentencepieceTokenizer accepts sp model as the first input and it should be of type u8 tensor");
+    OPENVINO_ASSERT(get_input_element_type(0) == element::u8, "SentencepieceTokenizer accepts sp model as the first input and it should be of type u8 tensor");

     auto input_size = get_input_size();
-    if(input_size == 2) {
-        FRONT_END_GENERAL_CHECK(
+    // sentencepiece model, string input, (unpacked special tokens)
+    if(input_size == 2 || input_size == 6) {
+        OPENVINO_ASSERT(
             // WA: f32 appeared as a placeholder for unknown type during intermediate conversion steps
             get_input_element_type(1) == element::string || get_input_element_type(1) == element::f32,
             "SentencepieceTokenizer accepts sentences as the second input and it should be of type string tensor");
-    } else if (input_size == 4) {
-        FRONT_END_GENERAL_CHECK(get_input_element_type(1) == element::i32, "SentencepieceTokenizer accepts begins offsets as the second and it should be of type i32 tensor");
-        FRONT_END_GENERAL_CHECK(get_input_element_type(2) == element::i32, "SentencepieceTokenizer accepts ends offsets as the third and it should be of type i32 tensor");
-        FRONT_END_GENERAL_CHECK(get_input_element_type(3) == element::u8, "SentencepieceTokenizer accepts sentence symbols as the fourth input and it should be of type u8 tensor");
+    // sentencepiece model, unpacked string input, (unpacked special tokens)
+    } else if (input_size == 4 || input_size == 8) {
+        check_string_input(this, 1);
     } else {
         OPENVINO_THROW("Unexpected input format. SentencepieceTokenizer accepts one string input or three decomposed string inputs (begins, ends, symbols)");
     };
+    if (input_size == 6 || input_size == 8) {
+        // unpacked special tokens
+        check_string_input(this, input_size - 4);
+        // special tokens ids
+        OPENVINO_ASSERT(this->get_input_element_type(input_size - 1) == element::i32, "Expected an i32 tensor for special tokens ids.");
+    };
+
     // The operation SentencepieceTokenizerExtensionOp has three outputs: sparse indices, sparse values
     // and dense shape
     set_output_type(0, element::i64, PartialShape{ Dimension(), Dimension(2) });
@@ -144,17 +163,40 @@ bool SentencepieceTokenizer::visit_attributes(AttributeVisitor& visitor) {
 }

 bool SentencepieceTokenizer::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
+    auto input_size = get_input_size();
     if (m_sp == nullptr) {
         m_sp = std::make_shared<SentencePieceProcessor>();
         init_sp_model_in_eval(inputs, m_sp);
-        CHECK_OK(m_sp->SetEncodeExtraOptions(form_extra_options(m_add_bos, m_add_eos, m_reverse)));
+        auto do_reverse = (m_reverse && input_size < 5); // do not reverse if special_tokens_re is used
+        CHECK_OK(m_sp->SetEncodeExtraOptions(form_extra_options(m_add_bos, m_add_eos, do_reverse)));
+    };
+    if (input_size > 5 && m_special_tokens_re == nullptr) {
+        auto special_tokens_begins = inputs[input_size - 4].data<const int32_t>();
+        auto special_tokens_ends = inputs[input_size - 3].data<const int32_t>();
+        auto special_tokens_chars = inputs[input_size - 2].data<const uint8_t>();
+        auto special_tokens_ids = inputs[input_size - 1].data<const int32_t>();
+
+        std::string special_tokens;
+        m_special_tokens_map = std::make_shared<absl::flat_hash_map<std::string, int32_t>>();
+        for (size_t i = 0; i < inputs[input_size - 4].get_size(); ++i) {
+            const std::string token = std::string(
+                special_tokens_chars + special_tokens_begins[i],
+                special_tokens_chars + special_tokens_ends[i]
+            );
+            if (!special_tokens.empty()) {
+                special_tokens += "|";
+            };
+            special_tokens += re2::RE2::QuoteMeta(token);
+
+            m_special_tokens_map->insert(std::pair{token, special_tokens_ids[i]});
+        };
+        m_special_tokens_re = std::make_shared<re2::RE2>("(" + special_tokens + ")");
     };

     std::vector<int64_t> sparse_indices;
     std::vector<int32_t> sparse_values;
     std::vector<int64_t> sparse_dense_shape;
-    auto input_size = get_input_size();

     int32_t batch_size;

     // used in case of string tensors
@@ -165,7 +207,7 @@ bool SentencepieceTokenizer::evaluate(TensorVector& outputs, const TensorVector&
     const int32_t* end_ids;
     const uint8_t* data;

-    if (input_size == 2) {
+    if (input_size == 2 || input_size == 6) {
         auto input_element_type = get_input_element_type(1);
         if(input_element_type == ov::element::string) {
             strings = inputs[1].data<const std::string>();
@@ -182,8 +224,8 @@ bool SentencepieceTokenizer::evaluate(TensorVector& outputs, const TensorVector&

     size_t max_token_id = 0;
     for (size_t batch_ind = 0; batch_ind < batch_size; ++batch_ind) {
-        absl::string_view sentence;
-        if (input_size == 2) {
+        std::string sentence;
+        if (input_size == 2 || input_size == 6) {
             sentence = strings[batch_ind];
         } else {
             auto begin_ind = begin_ids[batch_ind];
@@ -192,13 +234,43 @@ bool SentencepieceTokenizer::evaluate(TensorVector& outputs, const TensorVector&
         };

         std::vector<int32_t> ids;
-        CHECK_OK(m_sp->SampleEncode(sentence, m_nbest_size, m_alpha, &ids));
+        if (input_size < 5) {
+            CHECK_OK(m_sp->SampleEncode(sentence, m_nbest_size, m_alpha, &ids));
+        } else {
+            std::string special_token;
+            std::vector<int32_t> part_ids;
+            re2::StringPiece input(sentence);
+            auto cursor = input.begin();
+            const auto num_tokens_before = ids.size();
+            while (cursor != input.end()) {
+                if (re2::RE2::FindAndConsume(&input, *m_special_tokens_re, &special_token)) {
+                    auto before_special_token = absl::string_view(cursor, input.begin() - cursor - special_token.size());
+                    CHECK_OK(m_sp->SampleEncode(before_special_token, m_nbest_size, m_alpha, &part_ids));
+                    ids.insert(ids.end(), part_ids.begin(), part_ids.end());
+                    cursor = input.begin();
+
+                    auto token_and_id = m_special_tokens_map->find(special_token);
+                    if (token_and_id != m_special_tokens_map->end()) {
+                        ids.push_back(token_and_id->second);
+                    };
+                } else {
+                    CHECK_OK(m_sp->SampleEncode(input, m_nbest_size, m_alpha, &part_ids));
+                    ids.insert(ids.end(), part_ids.begin(), part_ids.end());
+                    cursor = input.end();
+                };
+            };
+
+            if (m_reverse && ids.size() - num_tokens_before > 1) {
+                std::reverse(ids.begin() + num_tokens_before, ids.end());
+            };
+        };
+
         // put into resulted vectors
         for (size_t token_id = 0; token_id < ids.size(); ++token_id) {
             sparse_indices.push_back(static_cast<int64_t>(batch_ind));
             sparse_indices.push_back(static_cast<int64_t>(token_id));
             sparse_values.push_back(static_cast<int32_t>(ids[token_id]));
-        }
+        };
         max_token_id = max_token_id < ids.size() ? ids.size() : max_token_id;
     }
     sparse_dense_shape.push_back(static_cast<int64_t>(batch_size));
@@ -219,7 +291,7 @@ bool SentencepieceTokenizer::has_evaluate() const {
 }

 std::shared_ptr<Node> SentencepieceTokenizer::clone_with_new_inputs(const OutputVector& new_args) const {
-    return std::make_shared<SentencepieceTokenizer>(new_args, m_sp, m_nbest_size, m_alpha, m_add_bos, m_add_eos, m_reverse);
+    return std::make_shared<SentencepieceTokenizer>(new_args, m_sp, m_special_tokens_re, m_special_tokens_map, m_nbest_size, m_alpha, m_add_bos, m_add_eos, m_reverse);
 }
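The core of the `sentence_piece.cpp` change is the new encoding loop: the input is scanned with a regex built from the quoted special tokens, the plain-text segments between matches are encoded by SentencePiece, and each matched special token is looked up directly in the id map. A simplified Python re-sketch of that logic, for illustration only (`sp_encode` stands in for the SentencePiece `SampleEncode` call and is an assumption, not part of the patch):

```python
import re


def encode_with_special_tokens(sentence, sp_encode, special_tokens_map):
    # Longest-first alternation over escaped tokens, analogous to what the C++
    # code builds with RE2::QuoteMeta; the capturing group keeps the matched
    # special tokens in the split result.
    pattern = "(" + "|".join(map(re.escape, sorted(special_tokens_map, reverse=True))) + ")"
    ids = []
    for part in re.split(pattern, sentence):
        if part in special_tokens_map:
            ids.append(special_tokens_map[part])  # special token -> dedicated id
        elif part:
            ids.extend(sp_encode(part))  # ordinary text -> SentencePiece
    return ids
```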
diff --git a/src/sentence_piece.hpp b/src/sentence_piece.hpp
index 0d6655a1d..db53f3b36 100644
--- a/src/sentence_piece.hpp
+++ b/src/sentence_piece.hpp
@@ -6,6 +6,8 @@
 #include 

 #include "absl/strings/str_format.h"
+#include "absl/container/flat_hash_map.h"
+#include "re2/re2.h"

 namespace sentencepiece {
     class SentencePieceProcessor;
@@ -19,8 +21,17 @@ namespace TemplateExtension {
         SentencepieceTokenizer() = default;
         SentencepieceTokenizer(const ov::OutputVector& args, int32_t nbest_size, float alpha, bool add_bos, bool add_eos, bool reverse);
-        SentencepieceTokenizer(const ov::OutputVector& args, const std::shared_ptr<sentencepiece::SentencePieceProcessor>& sp, int32_t nbest_size, float alpha,
-            bool add_bos, bool add_eos, bool reverse);
+        SentencepieceTokenizer(
+            const ov::OutputVector& args,
+            const std::shared_ptr<sentencepiece::SentencePieceProcessor>& sp,
+            const std::shared_ptr<re2::RE2>& special_tokens_re,
+            const std::shared_ptr<absl::flat_hash_map<std::string, int32_t>>& special_tokens_map,
+            int32_t nbest_size,
+            float alpha,
+            bool add_bos,
+            bool add_eos,
+            bool reverse
+        );

         bool visit_attributes(ov::AttributeVisitor& visitor) override;
@@ -34,6 +45,8 @@ namespace TemplateExtension {
     private:
         mutable std::shared_ptr<sentencepiece::SentencePieceProcessor> m_sp;
+        mutable std::shared_ptr<re2::RE2> m_special_tokens_re;
+        mutable std::shared_ptr<absl::flat_hash_map<std::string, int32_t>> m_special_tokens_map;
         int32_t m_nbest_size;
         float m_alpha;
         bool m_add_bos;
diff --git a/tests/tokenizers_test.py b/tests/tokenizers_test.py
index ad0a96925..cb344e9f8 100644
--- a/tests/tokenizers_test.py
+++ b/tests/tokenizers_test.py
@@ -448,8 +448,8 @@ def test_sentencepiece_model_tokenizer(sentencepice_tokenizers, test_string, do_
     for output_name, hf_result in hf_tokenized.items():
         # chatglm has token_type_ids output that we omit
         if (ov_result := ov_tokenized.get(output_name)) is not None:
-            assert ov_result.shape == hf_result.shape, f"{hf_result}\n{ov_result}"
-            assert np.all(ov_result == hf_result), f"{hf_result}\n{ov_result}"
+            assert ov_result.shape == hf_result.shape, f"\n{hf_result}\n{ov_result}"
+            assert np.all(ov_result == hf_result), f"\n{hf_result}\n{ov_result}"


 @pytest.mark.parametrize(
@@ -474,8 +474,8 @@ def test_hf_sentencepiece_tokenizers_multiple_strings(

     for output_name, hf_result in hf_tokenized.items():
         if (ov_result := ov_tokenized.get(output_name)) is not None:
-            assert ov_result.shape == hf_result.shape, f"{hf_result}\n{ov_result}"
-            assert np.all(ov_result == hf_result), f"{hf_result}\n{ov_result}"
+            assert ov_result.shape == hf_result.shape, f"\n{hf_result}\n{ov_result}"
+            assert np.all(ov_result == hf_result), f"\n{hf_result}\n{ov_result}"


 @pytest.mark.parametrize(
@@ -523,8 +523,8 @@ def test_hf_bpe_tokenizers_outputs(bpe_tokenizers, test_string, do_add_special_t
     for output_name, hf_result in hf_tokenized.items():
         # galactica tokenizer has 3 output, but model has 2 inputs
         if (ov_result := ov_tokenized.get(output_name)) is not None:
-            assert ov_result.shape == hf_result.shape, f"{hf_result}\n{ov_result}"
-            assert np.all(ov_result == hf_result), f"{hf_result}\n{ov_result}"
+            assert ov_result.shape == hf_result.shape, f"\n{hf_result}\n{ov_result}"
+            assert np.all(ov_result == hf_result), f"\n{hf_result}\n{ov_result}"


 @pytest.mark.parametrize(
@@ -549,8 +549,8 @@ def test_hf_bpe_tokenizers_multiple_strings(

     for output_name, hf_result in hf_tokenized.items():
         if (ov_result := ov_tokenized.get(output_name)) is not None:
-            assert ov_result.shape == hf_result.shape, f"{hf_result}\n{ov_result}"
-            assert np.all(ov_result == hf_result), f"{hf_result}\n{ov_result}"
+            assert ov_result.shape == hf_result.shape, f"\n{hf_result}\n{ov_result}"
+            assert np.all(ov_result == hf_result), f"\n{hf_result}\n{ov_result}"


 @pytest.mark.parametrize(
@@ -595,7 +595,7 @@ def test_tiktoken_tokenizers(tiktoken_tokenizers, test_string):

     for output_name, hf_result in hf_tokenized.items():
         if (ov_result := ov_tokenized.get(output_name)) is not None:
-            assert np.all(ov_result == hf_result), f"{hf_result}\n{ov_result}"
+            assert np.all(ov_result == hf_result), f"\n{hf_result}\n{ov_result}"


 @pytest.mark.parametrize(
@@ -619,8 +619,8 @@ def test_hf_tiktoken_tokenizers_multiple_strings(

     for output_name, hf_result in hf_tokenized.items():
         if (ov_result := ov_tokenized.get(output_name)) is not None:
-            assert ov_result.shape == hf_result.shape, f"{hf_result}\n{ov_result}"
-            assert np.all(ov_result == hf_result), f"{hf_result}\n{ov_result}"
+            assert ov_result.shape == hf_result.shape, f"\n{hf_result}\n{ov_result}"
+            assert np.all(ov_result == hf_result), f"\n{hf_result}\n{ov_result}"


 @pytest.mark.parametrize(
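An end-to-end smoke test of the new option might look as follows; the model id and prompt are placeholders, and importing `openvino_tokenizers` is assumed to register the custom operations needed by `compile_model`:

```python
from transformers import AutoTokenizer
from openvino import compile_model

from openvino_tokenizers import convert_tokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
ov_tokenizer = convert_tokenizer(hf_tokenizer, handle_special_tokens_with_re=True)
compiled = compile_model(ov_tokenizer)

# "</s>" should now be caught by the special-tokens regex and mapped to its id
# instead of being split into pieces by the SentencePiece model.
print(compiled(["Hello world </s>"])["input_ids"])
```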