Add Separate Special Token Handling To Sentencepiece (openvinotoolkit#197)

(cherry picked from commit b08a4ae)
apaniukov committed Jul 18, 2024
1 parent 04795c1 commit b76754a
Showing 6 changed files with 152 additions and 43 deletions.
23 changes: 17 additions & 6 deletions python/openvino_tokenizers/cli.py
@@ -45,7 +45,7 @@ def get_parser() -> ArgumentParser:
type=str,
help=(
"The model id of a tokenizer hosted inside a model repo on huggingface.co "
"or a path to a saved Huggingface tokenizer directory"
"or a path to a saved Huggingface tokenizer directory."
),
)
parser.add_argument(
@@ -152,7 +152,17 @@ def get_parser() -> ArgumentParser:
"tokenizer and then converts it to OpenVINO. Might result in slightly different tokenizer. "
"See models with _slow suffix https://github.com/openvinotoolkit/openvino_contrib/tree/master/modules/"
"custom_operations/user_ie_extensions/tokenizer/python#output-match-by-model to check the potential "
"difference between original and OpenVINO tokenizers"
"difference between original and OpenVINO tokenizers."
),
)
parser.add_argument(
"--handle-special-tokens-with-re",
"--handle_special_tokens_with_re",
required=False,
action="store_true",
help=(
"Use separete regex to handle special tokens for sentencepiece-based tokenizers. Use this option if the "
"converted tokenizer doesn't use special tokens during tokenization."
),
)
parser.add_argument(
@@ -162,7 +172,7 @@ def get_parser() -> ArgumentParser:
action="store_true",
help=(
"Pass `trust_remote_code=True` to `AutoTokenizer.from_pretrained`. It will "
"execute code present on the Hub on your local machine"
"execute code present on the Huggingface Hub on your machine!"
),
)
parser.add_argument(
@@ -172,7 +182,7 @@ def get_parser() -> ArgumentParser:
action=StringToTypeAction,
default=Type.i64,
choices=["i32", "i64"],
help="Type of the output tensors for tokenizer",
help="Type of the output tensors for tokenizer.",
)
parser.add_argument(
"--detokenizer-input-type",
@@ -181,7 +191,7 @@ def get_parser() -> ArgumentParser:
action=StringToTypeAction,
default=Type.i64,
choices=["i32", "i64"],
help="Type of the input tensor for detokenizer",
help="Type of the input tensor for detokenizer.",
)
parser.add_argument(
"--streaming-detokenizer",
@@ -190,7 +200,7 @@ def get_parser() -> ArgumentParser:
action="store_true",
help=(
"[Experimental] Modify SentencePiece based detokenizer to keep spaces leading space. "
"Can be used to stream a model output without TextStreamer buffer"
"Can be used to stream a model output without TextStreamer buffer."
),
)
return parser
@@ -232,6 +242,7 @@ def convert_hf_tokenizer() -> None:
detokenizer_input_type=args.detokenizer_input_type,
streaming_detokenizer=args.streaming_detokenizer,
use_max_padding=args.max_padding is not None,
handle_special_tokens_with_re=args.handle_special_tokens_with_re,
)
if not isinstance(converted, tuple):
converted = (converted,)
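For a quick sanity check of the new CLI flag, here is a minimal sketch — not part of this commit — that assumes the model id is the only required positional argument and uses an illustrative model name:

# Hypothetical smoke test for the new flag; model name is illustrative only.
from openvino_tokenizers.cli import get_parser

parser = get_parser()
args = parser.parse_args(["TinyLlama/TinyLlama-1.1B-Chat-v1.0", "--handle-special-tokens-with-re"])

# argparse derives the attribute name from the first option string, so both
# the dashed and underscored spellings populate the same field.
assert args.handle_special_tokens_with_re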
2 changes: 2 additions & 0 deletions python/openvino_tokenizers/convert_tokenizer.py
@@ -25,6 +25,7 @@ def convert_tokenizer(
detokenizer_input_type: Type = Type.i64,
streaming_detokenizer: bool = False,
use_max_padding: bool = False,
handle_special_tokens_with_re: bool = False,
) -> Union[Model, Tuple[Model, Model]]:
ov_tokenizers = None

@@ -50,6 +51,7 @@
add_special_tokens=add_special_tokens,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
handle_special_tokens_with_re=handle_special_tokens_with_re,
)
elif is_tiktoken_model(tokenizer_object):
logger.info("Convert tiktoken-based tokenizer")
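A hedged usage sketch of the same option through the Python API — the model id below is only an example; the keyword follows the signature added above:

# Convert a SentencePiece-based Hugging Face tokenizer, letting a separate
# regex handle special tokens instead of the SentencePiece model itself.
from transformers import AutoTokenizer
from openvino_tokenizers import convert_tokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
ov_tokenizer = convert_tokenizer(hf_tokenizer, handle_special_tokens_with_re=True)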
17 changes: 14 additions & 3 deletions python/openvino_tokenizers/hf_parser.py
@@ -30,6 +30,7 @@
TOKENIZER_NAME,
)
from .tokenizer_pipeline import (
BasePipelineStep,
BPETokenizationStep,
ByteFallbackStep,
BytesToCharsStep,
@@ -518,6 +519,7 @@ def convert_sentencepiece_model_tokenizer(
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: Optional[bool] = False,
add_prefix_space: Optional[bool] = None,
handle_special_tokens_with_re: bool = False,
) -> Union[Model, Tuple[Model, Model]]:
if not is_sentencepiece_model(hf_tokenizer):
raise OVTypeError("Cannot convert tokenizer of this type without `.model` file.")
@@ -604,7 +606,7 @@ def convert_sentencepiece_model_tokenizer(
add_tokens=add_tokens,
hf_tokenizer=hf_tokenizer,
skip_special_tokens=False,
add_prefix_space=add_prefix_space,
add_prefix_space=add_prefix_space and not handle_special_tokens_with_re,
)
sp_model = np.frombuffer(sp_model_string, dtype=np.uint8)
sp_model_node = as_node(sp_model)
@@ -623,16 +625,25 @@
input_node.set_friendly_name("string_input")
next_node = input_node.outputs()

if prepend_scheme == "first":
if prepend_scheme == "first" or (add_prefix_space and handle_special_tokens_with_re):
next_node = _get_factory().create("StringTensorUnpack", next_node).outputs()
next_node = RegexNormalizationStep.add_prefix_whitespace_to_not_whitespace_regex().get_ov_subgraph(next_node)
next_node = _get_factory().create("StringTensorPack", next_node).outputs()

do_left_padding = hf_tokenizer.padding_side == "left"

if handle_special_tokens_with_re:
tokens, ids = zip(*sorted(((token, id) for id, token in add_tokens.items()), reverse=True))
added_inputs = [
*BasePipelineStep.create_string_constant_node(tokens).outputs(),
make_constant_node(np.array(ids, dtype=np.int32), Type.i32).output(0),
]
else:
added_inputs = []

tokenizer_node = _get_factory().create(
"SentencepieceTokenizer",
[sp_model_node, *next_node],
[sp_model_node, *next_node] + added_inputs,
{
"add_bos": add_bos_token,
"add_eos": add_eos_token,
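A small worked example of the token/id packing above, with made-up added tokens (the real values come from the tokenizer's added vocabulary):

# Made-up id -> token mapping, standing in for the add_tokens dict above.
add_tokens = {1: "<s>", 2: "</s>", 32000: "<|user|>"}

# Reverse sort by token string: a longer token is listed before any token
# that is a prefix of it, which matters once the strings are joined into a
# single alternation regex on the C++ side.
tokens, ids = zip(*sorted(((token, id) for id, token in add_tokens.items()), reverse=True))

print(tokens)  # ('<|user|>', '<s>', '</s>')
print(ids)     # (32000, 1, 2)
# These two tuples become the extra node inputs: a packed string tensor with
# the tokens and an i32 constant with their ids.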
114 changes: 93 additions & 21 deletions src/sentence_piece.cpp
@@ -5,7 +5,6 @@
#include <functional>

#include "sentencepiece_processor.h"
#include "absl/container/flat_hash_map.h"

#include "openvino/op/util/framework_node.hpp"
#include "openvino/opsets/opset13.hpp"
@@ -72,7 +71,7 @@ std::string form_extra_options(bool add_bos, bool add_eos, bool reverse) {

void init_sp_model(const OutputVector& args, std::shared_ptr<SentencePieceProcessor>& sp) {
auto sp_model_const = as_type_ptr<Constant>(args[0].get_node_shared_ptr());
FRONT_END_GENERAL_CHECK(sp_model_const, "SentencepieceTokenizer expects SentencePiece model to be constant.");
OPENVINO_ASSERT(sp_model_const, "SentencepieceTokenizer expects SentencePiece model to be constant.");
auto spm_model = static_cast<const char*>(sp_model_const->get_data_ptr());
auto spm_model_size = sp_model_const->get_byte_size();

@@ -93,40 +92,60 @@ SentencepieceTokenizer::SentencepieceTokenizer(const OutputVector& args, int32_t
m_reverse(reverse), Op(args) {

init_sp_model(args, m_sp);
CHECK_OK(m_sp->SetEncodeExtraOptions(form_extra_options(m_add_bos, m_add_eos, m_reverse)));
auto do_reverse = (m_reverse && get_input_size() < 5); // do not reverse if special_tokens_re is used
CHECK_OK(m_sp->SetEncodeExtraOptions(form_extra_options(m_add_bos, m_add_eos, do_reverse)));
constructor_validate_and_infer_types();
}

SentencepieceTokenizer::SentencepieceTokenizer(const OutputVector& args, const std::shared_ptr<SentencePieceProcessor>& sp,
int32_t nbest_size, float alpha, bool add_bos, bool add_eos, bool reverse) :
SentencepieceTokenizer::SentencepieceTokenizer(
const OutputVector& args,
const std::shared_ptr<SentencePieceProcessor>& sp,
const std::shared_ptr<re2::RE2>& special_tokens_re,
const std::shared_ptr<absl::flat_hash_map<std::string, int32_t>>& special_tokens_map,
int32_t nbest_size,
float alpha,
bool add_bos,
bool add_eos,
bool reverse
) :
m_sp((sp == nullptr) ? std::make_shared<SentencePieceProcessor>(): sp),
m_special_tokens_re(special_tokens_re),
m_special_tokens_map(special_tokens_map),
m_nbest_size(nbest_size), m_alpha(alpha), m_add_bos(add_bos), m_add_eos(add_eos),
m_reverse(reverse), Op(args) {
// constructor above without sp argument never called when the node is created with python factory, so need to init and cache m_sp here
if (!m_sp->status().ok()) {
init_sp_model(args, m_sp);
CHECK_OK(m_sp->SetEncodeExtraOptions(form_extra_options(m_add_bos, m_add_eos, m_reverse)));
auto do_reverse = (m_reverse && get_input_size() < 5); // do not reverse if special_tokens_re is used
CHECK_OK(m_sp->SetEncodeExtraOptions(form_extra_options(m_add_bos, m_add_eos, do_reverse)));
};
constructor_validate_and_infer_types();
}

void SentencepieceTokenizer::validate_and_infer_types() {
FRONT_END_GENERAL_CHECK(get_input_element_type(0) == element::u8, "SentencepieceTokenizer accepts sp model as the first input and it should be of type u8 tensor");
OPENVINO_ASSERT(get_input_element_type(0) == element::u8, "SentencepieceTokenizer accepts sp model as the first input and it should be of type u8 tensor");

auto input_size = get_input_size();
if(input_size == 2) {
FRONT_END_GENERAL_CHECK(
// sentencepiece model, string input, (unpacked special tokens)
if(input_size == 2 || input_size == 6) {
OPENVINO_ASSERT(
// WA: f32 appeared as a placeholder for unknown type during intermediate conversion steps
get_input_element_type(1) == element::string || get_input_element_type(1) == element::f32,
"SentencepieceTokenizer accepts sentences as the second input and it should be of type string tensor");
} else if (input_size == 4) {
FRONT_END_GENERAL_CHECK(get_input_element_type(1) == element::i32, "SentencepieceTokenizer accepts begins offsets as the second and it should be of type i32 tensor");
FRONT_END_GENERAL_CHECK(get_input_element_type(2) == element::i32, "SentencepieceTokenizer accepts ends offsets as the third and it should be of type i32 tensor");
FRONT_END_GENERAL_CHECK(get_input_element_type(3) == element::u8, "SentencepieceTokenizer accepts sentence symbols as the fourth input and it should be of type u8 tensor");
// sentencepiece model, unpacked string input, (unpacked special tokens)
} else if (input_size == 4 || input_size == 8) {
check_string_input(this, 1);
} else {
OPENVINO_THROW("Unexpected input format. SentencepieceTokenizer accepts one string input or three decomposed string inputs (begins, ends, symbols)");
};

if (input_size == 6 || input_size == 8) {
// unpacked special tokens
check_string_input(this, input_size - 4);
// special tokens ids
OPENVINO_ASSERT(this->get_input_element_type(input_size - 1) == element::i32, "Expected an i32 tensor for special tokens ids.");
};

// The operation SentencepieceTokenizerExtensionOp has three outputs: sparse indices, sparse values
// and dense shape
set_output_type(0, element::i64, PartialShape{ Dimension(), Dimension(2) });
@@ -144,17 +163,40 @@ bool SentencepieceTokenizer::visit_attributes(AttributeVisitor& visitor) {
}

bool SentencepieceTokenizer::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
auto input_size = get_input_size();
if (m_sp == nullptr) {
m_sp = std::make_shared<SentencePieceProcessor>();
init_sp_model_in_eval(inputs, m_sp);
CHECK_OK(m_sp->SetEncodeExtraOptions(form_extra_options(m_add_bos, m_add_eos, m_reverse)));
auto do_reverse = (m_reverse && input_size < 5); // do not reverse if special_tokens_re is used
CHECK_OK(m_sp->SetEncodeExtraOptions(form_extra_options(m_add_bos, m_add_eos, do_reverse)));
};
if (input_size > 5 && m_special_tokens_re == nullptr) {
auto special_tokens_begins = inputs[input_size - 4].data<const int32_t>();
auto special_tokens_ends = inputs[input_size - 3].data<const int32_t>();
auto special_tokens_chars = inputs[input_size - 2].data<const uint8_t>();
auto special_tokens_ids = inputs[input_size - 1].data<const int32_t>();

std::string special_tokens;
m_special_tokens_map = std::make_shared<absl::flat_hash_map<std::string, int32_t>>();
for (size_t i = 0; i < inputs[input_size - 4].get_size(); ++i) {
const std::string token = std::string(
special_tokens_chars + special_tokens_begins[i],
special_tokens_chars + special_tokens_ends[i]
);
if (!special_tokens.empty()) {
special_tokens += "|";
};
special_tokens += re2::RE2::QuoteMeta(token);

m_special_tokens_map->insert(std::pair{token, special_tokens_ids[i]});
};
m_special_tokens_re = std::make_shared<re2::RE2>("(" + special_tokens + ")");
};

std::vector<int64_t> sparse_indices;
std::vector<int32_t> sparse_values;
std::vector<int64_t> sparse_dense_shape;

auto input_size = get_input_size();
int32_t batch_size;

// used in case of string tensors
@@ -165,7 +207,7 @@ bool SentencepieceTokenizer::evaluate(TensorVector& outputs, const TensorVector&
const int32_t* end_ids;
const uint8_t* data;

if (input_size == 2) {
if (input_size == 2 || input_size == 6) {
auto input_element_type = get_input_element_type(1);
if(input_element_type == ov::element::string) {
strings = inputs[1].data<const std::string>();
@@ -182,8 +224,8 @@ bool SentencepieceTokenizer::evaluate(TensorVector& outputs, const TensorVector&

size_t max_token_id = 0;
for (size_t batch_ind = 0; batch_ind < batch_size; ++batch_ind) {
absl::string_view sentence;
if (input_size == 2) {
std::string sentence;
if (input_size == 2 || input_size == 6) {
sentence = strings[batch_ind];
} else {
auto begin_ind = begin_ids[batch_ind];
@@ -192,13 +234,43 @@ bool SentencepieceTokenizer::evaluate(TensorVector& outputs, const TensorVector&
};

std::vector<int32_t> ids;
CHECK_OK(m_sp->SampleEncode(sentence, m_nbest_size, m_alpha, &ids));
if (input_size < 5) {
CHECK_OK(m_sp->SampleEncode(sentence, m_nbest_size, m_alpha, &ids));
} else {
std::string special_token;
std::vector<int32_t> part_ids;
re2::StringPiece input(sentence);
auto cursor = input.begin();
const auto num_tokens_before = ids.size();
while (cursor != input.end()) {
if (re2::RE2::FindAndConsume(&input, *m_special_tokens_re, &special_token)) {
auto before_special_token = absl::string_view(cursor, input.begin() - cursor - special_token.size());
CHECK_OK(m_sp->SampleEncode(before_special_token, m_nbest_size, m_alpha, &part_ids));
ids.insert(ids.end(), part_ids.begin(), part_ids.end());
cursor = input.begin();

auto token_and_id = m_special_tokens_map->find(special_token);
if (token_and_id != m_special_tokens_map->end()) {
ids.push_back(token_and_id->second);
};
} else {
CHECK_OK(m_sp->SampleEncode(input, m_nbest_size, m_alpha, &part_ids));
ids.insert(ids.end(), part_ids.begin(), part_ids.end());
cursor = input.end();
};
};

if (m_reverse && ids.size() - num_tokens_before > 1) {
std::reverse(ids.begin() + num_tokens_before, ids.end());
};
};

// put into resulted vectors
for (size_t token_id = 0; token_id < ids.size(); ++token_id) {
sparse_indices.push_back(static_cast<int64_t>(batch_ind));
sparse_indices.push_back(static_cast<int64_t>(token_id));
sparse_values.push_back(static_cast<int32_t>(ids[token_id]));
}
};
max_token_id = max_token_id < ids.size() ? ids.size() : max_token_id;
}
sparse_dense_shape.push_back(static_cast<int64_t>(batch_size));
Expand All @@ -219,7 +291,7 @@ bool SentencepieceTokenizer::has_evaluate() const {
}

std::shared_ptr<Node> SentencepieceTokenizer::clone_with_new_inputs(const OutputVector& new_args) const {
return std::make_shared<SentencepieceTokenizer>(new_args, m_sp, m_nbest_size, m_alpha, m_add_bos, m_add_eos, m_reverse);
return std::make_shared<SentencepieceTokenizer>(new_args, m_sp, m_special_tokens_re, m_special_tokens_map, m_nbest_size, m_alpha, m_add_bos, m_add_eos, m_reverse);
}


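The core of the new evaluate branch is easier to read as pseudocode. The Python sketch below is an approximation, not the C++ above; sp_encode and special_tokens_map are hypothetical stand-ins for the SentencePiece processor and the map built from the extra op inputs. It shows the split-and-encode idea: find special tokens with one alternation regex, run SentencePiece only on the plain text between matches, and append the matched special-token ids directly.

import re

def encode_with_special_tokens(sentence, sp_encode, special_tokens_map):
    # sp_encode: callable running SentencePiece on a plain-text chunk, returns ids.
    # special_tokens_map: e.g. {"<s>": 1, "</s>": 2}; assumed non-empty, since the
    # C++ path is only taken when special-token inputs are present.
    pattern = re.compile("|".join(re.escape(token) for token in special_tokens_map))
    ids, cursor = [], 0
    for match in pattern.finditer(sentence):
        ids.extend(sp_encode(sentence[cursor:match.start()]))  # text before the special token
        ids.append(special_tokens_map[match.group(0)])          # the special token's own id
        cursor = match.end()
    ids.extend(sp_encode(sentence[cursor:]))                     # trailing text after the last match
    return ids

When the reverse option is set, the C++ code reverses each sentence's ids after this loop instead of letting SentencePiece reverse the pieces, which is why the extra-options string drops the reverse flag when special-token inputs are present.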
17 changes: 15 additions & 2 deletions src/sentence_piece.hpp
@@ -6,6 +6,8 @@

#include <openvino/op/op.hpp>
#include "absl/strings/str_format.h"
#include "absl/container/flat_hash_map.h"
#include "re2/re2.h"

namespace sentencepiece {
class SentencePieceProcessor;
@@ -19,8 +21,17 @@ namespace TemplateExtension {

SentencepieceTokenizer() = default;
SentencepieceTokenizer(const ov::OutputVector& args, int32_t nbest_size, float alpha, bool add_bos, bool add_eos, bool reverse);
SentencepieceTokenizer(const ov::OutputVector& args, const std::shared_ptr<sentencepiece::SentencePieceProcessor>& sp, int32_t nbest_size, float alpha,
bool add_bos, bool add_eos, bool reverse);
SentencepieceTokenizer(
const ov::OutputVector& args,
const std::shared_ptr<sentencepiece::SentencePieceProcessor>& sp,
const std::shared_ptr<re2::RE2>& special_tokens_re,
const std::shared_ptr<absl::flat_hash_map<std::string, int32_t>>& special_tokens_map,
int32_t nbest_size,
float alpha,
bool add_bos,
bool add_eos,
bool reverse
);

bool visit_attributes(ov::AttributeVisitor& visitor) override;

@@ -34,6 +45,8 @@ namespace TemplateExtension {

private:
mutable std::shared_ptr<sentencepiece::SentencePieceProcessor> m_sp;
mutable std::shared_ptr<re2::RE2> m_special_tokens_re;
mutable std::shared_ptr<absl::flat_hash_map<std::string, int32_t>> m_special_tokens_map;
int32_t m_nbest_size;
float m_alpha;
bool m_add_bos;