Improve BPE (openvinotoolkit#281)
* use list

* use priority queue

* some corrections

* thread safe use of RE2

* optimizations and cleanup

* wip

* fix segfault

* use shared_ptrs instead of raw pointers

* minor corrections

* fix comparing in PriorityQueue

* minor improvements
pavel-esir authored Oct 21, 2024
1 parent 57b236f commit 17fc3c1
Showing 3 changed files with 186 additions and 38 deletions.
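In outline, the PR replaces the old quadratic merge loop (rescan the whole token sequence for the lowest-ranked pair, merge, repeat) with a doubly linked token list plus a priority queue of candidate pairs, so each merge costs heap work instead of a full rescan. The following minimal, self-contained sketch shows the same idea; the `bpe_merge` function, its index-based links, and the snapshot-based staleness check are illustrative stand-ins, not the PR's actual classes:

// Sketch: priority-queue-driven BPE merging (illustrative, not the PR's API).
// `merges` maps a token pair to {rank, new_token_id}; lower rank merges first.
#include <cstdint>
#include <iostream>
#include <map>
#include <queue>
#include <tuple>
#include <utility>
#include <vector>

using Pair = std::pair<int32_t, int32_t>;
using MergeTable = std::map<Pair, std::pair<int32_t, int32_t>>;

std::vector<int32_t> bpe_merge(std::vector<int32_t> tokens, const MergeTable& merges) {
    const int n = static_cast<int>(tokens.size());
    std::vector<int> prev(n), next(n);
    std::vector<bool> alive(n, true);
    for (int i = 0; i < n; ++i) { prev[i] = i - 1; next[i] = i + 1; }

    // Heap entry: (rank, left index, snapshot of the pair). std::greater turns
    // the max-heap into a min-heap; the snapshot exposes stale entries.
    using Entry = std::tuple<int32_t, int, Pair>;
    std::priority_queue<Entry, std::vector<Entry>, std::greater<Entry>> pq;
    auto push_pair = [&](int i) {
        if (i < 0) return;
        int j = next[i];
        if (j >= n) return;
        auto it = merges.find({tokens[i], tokens[j]});
        if (it != merges.end()) pq.emplace(it->second.first, i, Pair{tokens[i], tokens[j]});
    };
    for (int i = 0; i + 1 < n; ++i) push_pair(i);

    while (!pq.empty()) {
        auto [rank, i, snap] = pq.top();
        pq.pop();
        int j = next[i];
        // Drop entries whose nodes were merged away or whose tokens changed.
        if (!alive[i] || j >= n || !alive[j] || Pair{tokens[i], tokens[j]} != snap) continue;
        tokens[i] = merges.at(snap).second;  // left node takes the merged id
        alive[j] = false;                    // right node is spliced out
        next[i] = next[j];
        if (next[j] < n) prev[next[j]] = i;
        push_pair(prev[i]);                  // re-examine the two new neighbor pairs
        push_pair(i);
    }

    std::vector<int32_t> out;
    for (int i = 0; i < n; i = next[i]) out.push_back(tokens[i]);
    return out;
}

int main() {
    MergeTable merges{{{97, 98}, {0, 256}},    // 'a','b' -> 256
                      {{256, 99}, {1, 257}}};  // 256,'c' -> 257
    for (int32_t t : bpe_merge({97, 98, 99, 97, 98}, merges))
        std::cout << t << ' ';  // prints: 257 256
    std::cout << '\n';
}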
132 changes: 98 additions & 34 deletions src/bpe_tokenizer.cpp
@@ -5,6 +5,7 @@
#include "bpe_tokenizer.hpp"
#include "openvino/opsets/opset13.hpp"
#include "absl/strings/str_format.h"
#include <queue>
using namespace ov;
using namespace ov::opset13;

@@ -177,65 +178,128 @@ bool BPETokenizer::evaluate(ov::TensorVector& outputs, const ov::TensorVector& i
return true;
}


std::pair<std::pair<int32_t, int32_t>, size_t> BPETokenizerImpl::get_min_rank_pair(Tokens tokens) {
    int min_rank = INT_MAX;
    std::pair<int32_t, int32_t> min_rank_pair = {tokens[0], tokens[1]};
    size_t position = -1;
    for (size_t i = 0; i < tokens.size() - 1; i++) {
        auto pair = std::pair(tokens[i], tokens[i + 1]);
        if (m_merges.count(pair) && m_merges.at(pair).first < min_rank) {
            min_rank = m_merges.at(pair).first;
            min_rank_pair = pair;
            position = i;
        }
    }
    return {min_rank_pair, position};
}

struct CompareRank {
    bool operator()(const std::tuple<int32_t, int32_t, TokenNode, TokenNode, int32_t>& lhs,
                    const std::tuple<int32_t, int32_t, TokenNode, TokenNode, int32_t>& rhs) const {
        // Compare based on positions in merges, but if positions in merges match,
        // prefer pairs which are closer to the beginning of the sequence.
        return (std::get<0>(lhs) != std::get<0>(rhs)) ? std::get<0>(lhs) > std::get<0>(rhs) : std::get<4>(lhs) > std::get<4>(rhs);
    }
};

Tokens BPETokenizerImpl::tokenize(std::string& text) {
std::vector<int32_t> BPETokenizerImpl::tokenize(std::string& text) {
if (m_cache.count(text)) {
return m_cache.at(text);
}

// For models with end_suffix (e.g. </w>) need to add suffix before looking them up in the vocabulary/prefix tree.
text += m_end_suffix;
// TODO: CVS-150387 Implement suffix_indicator.

// Initialize sequence of integer tokens by looking up
// the longest matching sequence in the prefix tree.
Tokens res;
res.reserve(text.length());

// Initialize sequence of integer tokens by looking up the longest match in the prefix tree.
TokensList res;
const auto text_vec = std::vector<unsigned char>(text.begin(), text.end());
for(int idx = 0; idx < text.size(); ) {
for (int idx = 0; idx < text.size();) {
auto r = m_trie->find_longest(text_vec, idx);
if (r != -1) {
res.emplace_back(r);
res.insert(r);
} else if (m_byte_fallback) {
res.emplace_back(m_vocab.at(absl::StrFormat("<0x%02X>", static_cast<unsigned char>(text[idx]))));
res.insert(m_vocab.at(absl::StrFormat("<0x%02X>", static_cast<unsigned char>(text[idx]))));
idx++;
} else {
if (!m_fuse_unk || res.back() != -1){
res.emplace_back(m_unk_token_id);
if (!m_fuse_unk || (res.tail->data) != -1) {
res.insert(m_unk_token_id);
}
idx++;
}
};
}
size_t initial_num_tokens = res.size();

while (res.size() >= 2) {
    auto [pair, idx] = get_min_rank_pair(res);
    if (idx == -1) {
        break;
    }
    res.erase(res.begin() + idx, res.begin() + idx + 2);
    res.insert(res.begin() + idx, m_merges.at(pair).second);
}

// Prepare a priority queue to store pairs with their ranks:
// (position in merges, rank, iterator to first, iterator to second, replacement sequence number).
using QueueEntry = std::tuple<int32_t, int32_t, TokenNode, TokenNode, int32_t>;
std::priority_queue<QueueEntry, std::vector<QueueEntry>, CompareRank> pq;

// Fill the priority queue with the initial pairs from the TokensList.
TokenNode curr_node = res.head;
OPENVINO_ASSERT(curr_node != nullptr);
TokenNode next_node = curr_node->next;

// Replacement sequence number, used in CompareRank: when merges have the
// same position, prefer replacements which occurred earlier.
int32_t i = 0;
while (next_node) {
    auto pair = std::make_pair(curr_node->data, next_node->data);
    if (m_merges.count(pair)) {
        auto [idx, rank] = m_merges.at(pair);
        pq.emplace(idx, rank, curr_node, next_node, i);
    }
    curr_node = next_node;
    next_node = curr_node->next;
    i++;
}

// Pairs that became invalid because one of their nodes was merged away.
std::unordered_set<std::pair<TokenNode, TokenNode>, NodePairHash, NodePairEqual> invalid_pairs;
while (!pq.empty() && res.size() >= 2) {
auto [idx, rank, first_it, second_it, position] = pq.top();
pq.pop();

// Skip pairs that were invalidated by an earlier merge.
if (invalid_pairs.count({first_it, second_it})) {
continue;
}

// Mark old neighbors as invalid.
if (first_it != res.head) {
invalid_pairs.insert({first_it->prev, first_it});
}
if (second_it != res.tail) {
invalid_pairs.insert({second_it, second_it->next});
}

// Merge the pair.
auto new_node = res.merge_neighbors(first_it, second_it, rank);

// Push the new pairs created by the merge onto the priority queue.
if (first_it->prev) {
auto prev_pair = std::make_pair(first_it->prev->data, new_node->data);

if (m_merges.count(prev_pair)) {
auto [idx, rank] = m_merges.at(prev_pair);
pq.emplace(idx, rank, first_it->prev, new_node, i);
}
}

if (second_it->next) {
auto next_pair = std::make_pair(new_node->data, second_it->next->data);

if (m_merges.count(next_pair)) {
auto [idx, rank] = m_merges.at(next_pair);
pq.emplace(idx, rank, new_node, second_it->next, i);
}
}
i++;
}

std::vector<int32_t> res_vec;
res_vec.reserve(res.size());
TokenNode node = res.head;
while (node) {
res_vec.emplace_back(node->data);
node = node->next;
}

// TODO: Check if LRU Cache is more effective.
if (m_cache.size() < m_cache_capacity && initial_num_tokens > 2) {
m_cache.insert({text, res});
m_cache.insert({text, res_vec});
}
return res;
return res_vec;
}

BPETokenizerImpl::BPETokenizerImpl(
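A note on the comparator used above: `std::priority_queue` is a max-heap under its comparator, so `CompareRank` returning `lhs > rhs` on the merge position (and on the sequence number as a tie-break) makes the queue pop the lowest position first. A tiny stand-alone illustration of that ordering, with hypothetical (position, sequence) pairs rather than real queue entries:

#include <iostream>
#include <queue>
#include <utility>
#include <vector>

// Same shape as CompareRank, reduced to (merge position, sequence number):
// returning `a > b` on both fields turns the max-heap into a min-heap.
struct Cmp {
    bool operator()(const std::pair<int, int>& a, const std::pair<int, int>& b) const {
        return (a.first != b.first) ? a.first > b.first : a.second > b.second;
    }
};

int main() {
    std::priority_queue<std::pair<int, int>, std::vector<std::pair<int, int>>, Cmp> pq;
    pq.push({5, 0});
    pq.push({2, 2});
    pq.push({2, 1});
    while (!pq.empty()) {
        auto [pos, seq] = pq.top();
        pq.pop();
        std::cout << pos << '/' << seq << ' ';  // prints: 2/1 2/2 5/0
    }
    std::cout << '\n';
}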
90 changes: 87 additions & 3 deletions src/bpe_tokenizer.hpp
@@ -18,7 +18,92 @@
using TextMerges = std::vector<std::pair<std::string, std::string>>;
using Merges = std::map<std::pair<int32_t, int32_t>, std::pair<int32_t, int32_t>>;
using Vocab = std::unordered_map<std::string, unsigned int>;
using Tokens = std::vector<int32_t>;

template <typename T = int32_t>
class TokensList {
public:
struct Node {
T data;
std::shared_ptr<Node> prev;
std::shared_ptr<Node> next;
Node(const T& data) : data(data), prev(nullptr), next(nullptr) {}
};

private:
    size_t m_size;

public:
size_t size() const {
return m_size;
}

std::shared_ptr<Node> head;
std::shared_ptr<Node> tail;

TokensList() : head(nullptr), tail(nullptr), m_size(0) {}

    ~TokensList() {
        // prev and next are both shared_ptr, so adjacent nodes form reference
        // cycles; break the back links while walking, or the chain would leak.
        while (head) {
            head->prev.reset();
            head = head->next;
        }
        tail.reset();
    }

void insert(const T& data) {
std::shared_ptr<Node> new_node = std::make_shared<Node>(data);
if (!head) {
head = tail = new_node;
} else {
tail->next = new_node;
new_node->prev = tail;
tail = new_node;
}
m_size++;
}

    std::shared_ptr<Node> merge_neighbors(std::shared_ptr<Node> first, std::shared_ptr<Node> second, const T& new_data) {
        OPENVINO_ASSERT(first && second && first->next == second,
                        "Nodes must be consecutive and non-null");

std::shared_ptr<Node> new_node = std::make_shared<Node>(new_data);

new_node->prev = first->prev;
new_node->next = second->next;

if (first->prev) {
first->prev->next = new_node;
} else {
head = new_node;
}

if (second->next) {
second->next->prev = new_node;
} else {
tail = new_node;
}

        // The spliced-out nodes are left linked on purpose: the priority queue
        // may still reference them, and shared_ptr keeps them valid meanwhile.
m_size -= 1;
return new_node;
}
};

// Define a custom hash function for std::pair
struct NodePairHash {
std::size_t operator()(const std::pair<std::shared_ptr<TokensList<int32_t>::Node>, std::shared_ptr<TokensList<int32_t>::Node>>& pair) const {
auto hash1 = std::hash<std::shared_ptr<TokensList<int32_t>::Node>>{}(pair.first);
auto hash2 = std::hash<std::shared_ptr<TokensList<int32_t>::Node>>{}(pair.second);
return hash1 ^ (hash2 << 1); // Combine the two hash values
}
};

// Define a custom equality function for std::pair
struct NodePairEqual {
bool operator()(const std::pair<std::shared_ptr<TokensList<int32_t>::Node>, std::shared_ptr<TokensList<int32_t>::Node>>& lhs,
const std::pair<std::shared_ptr<TokensList<int32_t>::Node>, std::shared_ptr<TokensList<int32_t>::Node>>& rhs) const {
return lhs.first == rhs.first && lhs.second == rhs.second;
}
};

using TokenNode = std::shared_ptr<TokensList<int32_t>::Node>;

class BPETokenizerImpl {
private:
@@ -32,7 +117,6 @@ class BPETokenizerImpl {
bool m_fuse_unk = false;
size_t m_cache_capacity;
std::unordered_map<std::string, std::vector<int32_t>> m_cache;
std::pair<std::pair<int32_t, int32_t>, size_t> get_min_rank_pair(Tokens tokens);
public:
BPETokenizerImpl(Vocab vocab, Merges merges): m_vocab(vocab), m_merges(merges) {};
BPETokenizerImpl(
@@ -44,7 +128,7 @@
bool fuse_unk = false,
bool byte_fallback = false
);
Tokens tokenize(std::string& text);
std::vector<int32_t> tokenize(std::string& text);
};


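One design note on `TokensList`: because `prev` and `next` are both `shared_ptr`, each adjacent pair of nodes forms a reference cycle, which is why the destructor has to break links explicitly. A common alternative, sketched here as a hypothetical variant rather than what the PR does, keeps ownership one-directional by making the back edge a `weak_ptr`:

#include <cstdint>
#include <iostream>
#include <memory>

// Cycle-free doubly linked node: `next` owns, `prev` observes, so dropping
// the head releases the whole chain without a manual unlink pass.
struct Node {
    int32_t data;
    std::weak_ptr<Node> prev;   // back edge holds no ownership
    std::shared_ptr<Node> next;
    explicit Node(int32_t d) : data(d) {}
    ~Node() { std::cout << "freed " << data << '\n'; }
};

int main() {
    auto a = std::make_shared<Node>(1);
    auto b = std::make_shared<Node>(2);
    a->next = b;
    b->prev = a;
    if (auto p = b->prev.lock())  // backward traversal must lock() first
        std::cout << "prev of 2 is " << p->data << '\n';
    a.reset();  // frees node 1 immediately: only the weak edge pointed back
    b.reset();  // frees node 2
}

The PR instead keeps both edges strong so that spliced-out nodes remain navigable from the priority queue after a merge; the cost is the explicit cleanup in the destructor.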
2 changes: 1 addition & 1 deletion src/regex_normalization.cpp
@@ -88,7 +88,7 @@ bool RegexNormalization::evaluate(ov::TensorVector& outputs, const ov::TensorVec
m_replace_pattern = std::string(inputs[pattern_input + 1].data<const char>(), inputs[pattern_input + 1].get_size());
m_search_pattern_pcre2 = std::make_shared<PCRE2Wrapper>(search_pattern);
}

return evaluate_normalization_helper(
outputs, inputs,
[this](const std::string& str) -> std::string {
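The `regex_normalization.cpp` hunk ties into the "thread safe use of RE2" item in the commit message: a newly supplied pattern is compiled into a fresh `shared_ptr<PCRE2Wrapper>` rather than mutating the existing object in place. Below is a generic sketch of that publish-via-shared_ptr pattern, with a hypothetical `Matcher` standing in for the real wrapper; the PR itself may rely on different synchronization:

#include <memory>
#include <string>

// Hypothetical stand-in for a compiled-regex wrapper such as PCRE2Wrapper.
struct Matcher {
    explicit Matcher(std::string p) : pattern(std::move(p)) {}
    std::string pattern;
};

class Normalizer {
    std::shared_ptr<const Matcher> m_matcher;  // replaced wholesale, never mutated
public:
    explicit Normalizer(std::string p)
        : m_matcher(std::make_shared<const Matcher>(std::move(p))) {}

    // Writer: compile a new matcher, then publish it with an atomic swap.
    // Readers that grabbed the old pointer keep a valid object until they drop it.
    void recompile(std::string p) {
        std::atomic_store(&m_matcher, std::make_shared<const Matcher>(std::move(p)));
    }

    // Reader: take one local strong reference, then use it freely.
    std::string current_pattern() const {
        return std::atomic_load(&m_matcher)->pattern;
    }
};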
