Commit ae66fda

current update

ArthurZucker committed Jul 12, 2024
1 parent 0dd9462 commit ae66fda
Showing 2 changed files with 36 additions and 101 deletions.
bindings/python/src/tokenizer.rs (1 addition & 1 deletion)
@@ -1177,7 +1177,7 @@ impl PyTokenizer {
/// Returns:
/// :obj:`int`: The number of tokens that were created in the vocabulary
#[pyo3(text_signature = "(self, old_tokens, new_tokens)")]
fn assing_tokens(
fn assign_tokens(
&mut self,
old_tokens: &Bound<'_, PyList>,
new_tokens: &Bound<'_, PyList>,
tokenizers/src/tokenizer/added_vocabulary.rs (35 additions & 100 deletions)
@@ -4,8 +4,10 @@ use super::{
use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
use regex::Regex;
use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
use std::collections::{HashMap, HashSet};

use std::{
collections::{HashMap, HashSet},
sync::{Arc, Mutex},
};
/// Represent a token added by the user on top of the existing Model vocabulary.
/// AddedToken can be configured to specify the behavior they should have in various situations
/// like:
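The list of per-situation options is collapsed in this diff view. As a rough illustration of how such a token is typically configured from outside the crate, here is a hedged sketch; the `AddedToken::from` constructor and the `single_word`/`lstrip`/`rstrip` builder methods are assumed from the crate's public API and do not appear in this diff.

```rust
// Hypothetical usage sketch, not part of this commit; builder-method names are
// assumptions based on the crate's public API.
use tokenizers::AddedToken;

fn custom_special_token() -> AddedToken {
    AddedToken::from("<custom_pad>", true) // `true` flags the token as special
        .single_word(true)                 // only match as a standalone word
        .lstrip(false)                     // keep whitespace to the left of a match
        .rstrip(false)                     // keep whitespace to the right of a match
}
```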
@@ -142,19 +144,12 @@ fn space_rightmost_at_start(sentence: &str) -> usize {
pub struct AddedVocabulary {
/// Contains the mapping from String (token content) to ID. This map contains both special
/// tokens and classic added tokens that were added to this vocabulary.
added_tokens_map: HashMap<String, u32>,
added_tokens_map: Arc<Mutex<HashMap<String, u32>>>,
/// Contains the mapping from ID to AddedToken for all the added tokens, both special
/// and classic.
added_tokens_map_r: HashMap<u32, AddedToken>,

added_tokens_map_r: Arc<Mutex<HashMap<u32, AddedToken>>>,
/// Contains only the classic AddedToken, in the specific order the user gave them.
added_tokens: Vec<AddedToken>,
/// Contains only the special AddedToken, in the specific order the user gave them.
special_tokens: Vec<AddedToken>,

/// A Set containing all the special tokens for easy access while decoding. This lets
/// us remove them easily with O(1) complexity.
special_tokens_set: HashSet<String>,

/// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
split_trie: MatchingSet,
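The core change in this struct is wrapping both maps in `Arc<Mutex<...>>`, so they can be shared and mutated behind shared references. Below is a minimal, self-contained sketch of that pattern using only the standard library; it is the idiom the new fields rely on, not code from this crate.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    // Clones of the Arc all point at the same mutex-protected map.
    let added: Arc<Mutex<HashMap<String, u32>>> = Arc::new(Mutex::new(HashMap::new()));

    let writer = {
        let added = Arc::clone(&added);
        thread::spawn(move || {
            // Mutation goes through the lock, so shared references suffice.
            added.lock().unwrap().insert("<new_token>".to_string(), 0);
        })
    };
    writer.join().unwrap();

    // Readers take the lock and clone a snapshot, which is why `get_vocab`
    // below now returns an owned HashMap instead of a reference.
    let snapshot = added.lock().unwrap().clone();
    assert_eq!(snapshot.get("<new_token>"), Some(&0));
}
```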
@@ -176,11 +171,9 @@ impl AddedVocabulary {
.build::<_, &&[u8]>([])
.expect("The normalized trie should build correctly");
Self {
added_tokens_map: HashMap::new(),
added_tokens_map_r: HashMap::new(),
added_tokens_map: Arc::new(Mutex::new(HashMap::new())),
added_tokens_map_r: Arc::new(Mutex::new(HashMap::new())),
added_tokens: vec![],
special_tokens: vec![],
special_tokens_set: HashSet::new(),
split_trie: (trie, vec![]),
split_normalized_trie: (normalized_trie, vec![]),
encode_special_tokens: false,
@@ -189,46 +182,36 @@ impl AddedVocabulary {
/// Size of the additional vocabulary
#[allow(dead_code)] // Suppress the "method is never used" warning
pub fn len(&self) -> usize {
self.added_tokens_map.len()
self.added_tokens_map.lock().unwrap().len()
}

/// Whether or not this vocabulary is empty
pub fn is_empty(&self) -> bool {
self.added_tokens_map.is_empty()
self.added_tokens_map.lock().unwrap().is_empty()
}

/// Get the additional vocabulary
pub fn get_vocab(&self) -> &HashMap<String, u32> {
&self.added_tokens_map
pub fn get_vocab(&self) -> HashMap<String, u32> {
self.added_tokens_map.lock().unwrap().clone()
}

/// Get the additional vocabulary with the AddedTokens
pub fn get_added_tokens_decoder(&self) -> &HashMap<u32, AddedToken> {
&self.added_tokens_map_r
pub fn get_added_tokens_decoder(&self) -> HashMap<u32, AddedToken> {
self.added_tokens_map_r.lock().unwrap().clone()
}

/// Get the id matching one of our tokens if it exists
pub fn token_to_id(&self, token: &str, model: &impl Model) -> Option<u32> {
self.added_tokens_map
.get(token)
.copied()
.or_else(|| model.token_to_id(token))
let added_tokens_map = self.added_tokens_map.lock().unwrap();
let id = added_tokens_map.get(token).copied();
id.or_else(|| model.token_to_id(token))
}

/// Get the token matching the given id if it exists
#[deprecated(
since = "0.19.0",
note = "please use `added_vocabulary.simple_id_to_token(id).or_else(|| model.id_to_token(id)` instead"
)]
pub fn id_to_token(&self, id: u32, model: &impl Model) -> Option<String> {
self.added_tokens_map_r
.get(&id)
.map(|t| t.content.clone())
.or_else(|| model.id_to_token(id))
}

pub fn simple_id_to_token(&self, id: u32) -> Option<String> {
self.added_tokens_map_r.get(&id).map(|t| t.content.clone())
}

//
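The deprecation note above recommends a two-step lookup: consult the added vocabulary first, then fall back to the model. The sketch below illustrates that ordering with a mock model type standing in for the crate's `Model` trait; only the lookup order comes from the note, the mock itself is hypothetical.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for the crate's `Model` trait; only `id_to_token` matters here.
struct MockModel {
    vocab_r: HashMap<u32, String>,
}

impl MockModel {
    fn id_to_token(&self, id: u32) -> Option<String> {
        self.vocab_r.get(&id).cloned()
    }
}

// Mirrors `added_vocabulary.simple_id_to_token(id).or_else(|| model.id_to_token(id))`:
// an added token with a given id shadows the model's entry for that id.
fn id_to_token(added_r: &HashMap<u32, String>, model: &MockModel, id: u32) -> Option<String> {
    added_r.get(&id).cloned().or_else(|| model.id_to_token(id))
}

fn main() {
    let added_r = HashMap::from([(50_000, "<custom>".to_string())]);
    let model = MockModel {
        vocab_r: HashMap::from([(0, "hello".to_string())]),
    };
    assert_eq!(id_to_token(&added_r, &model, 0), Some("hello".to_string()));
    assert_eq!(id_to_token(&added_r, &model, 50_000), Some("<custom>".to_string()));
}
```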
@@ -253,6 +236,11 @@ impl AddedVocabulary {
normalizer: Option<&N>,
) -> usize {
self.add_tokens(tokens, model, normalizer)
}

/// Get the token matching the given id if it exists
pub fn simple_id_to_token(&self, id: &u32) -> Option<String> {
let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap();
let token = added_tokens_map_r.get(id).map(|t| t.content.clone());
token
}

/// Reassigns a token's content to new content. This helps users who want to
@@ -276,69 +264,16 @@ impl AddedVocabulary {
}
}

/// Add some tokens to the vocabulary
pub fn add_tokens<N: Normalizer>(
&mut self,
tokens: &[AddedToken],
model: &impl Model,
normalizer: Option<&N>,
) -> usize {
// Handle special tokens (if any)
for token in tokens {
if token.special
&& !token.content.is_empty()
&& !self.special_tokens_set.contains(&token.content)
{
self.special_tokens.push(token.to_owned());
self.special_tokens_set.insert(token.content.clone());
}
}

// Then we delegate to `add_tokens`, that will take care of refreshing added tokens too.
let mut ignored = 0;
for token in tokens {
if token.content.is_empty() || self.added_tokens_map_r.values().any(|val| val == token)
{
ignored += 1;
continue;
}
// If a token is already part of the vocabulary, we mark it as added
let new_id = if let Some(new_id) = self.token_to_id(&token.content, model) {
new_id
} else {
self.added_tokens_map.values().cloned().max().map_or(
model.get_vocab_size() as u32,
|max| {
if (max >= model.get_vocab_size() as u32) || model.get_vocab_size() == 0 {
max + 1
} else {
model.get_vocab_size() as u32
}
},
)
};
// Make sure we modify the previous entry
self.added_tokens_map
.entry(token.content.clone())
.and_modify(|old_id| *old_id = new_id)
.or_insert_with(|| new_id);
// Update the current revert operation
self.added_tokens_map_r
.entry(new_id)
.and_modify(|t| *t = token.clone())
.or_insert_with(|| token.clone());
// Make sure to remove previous entry (if the token gets a new id)

// Finally add the token to the classic set if it is not special
if !self.special_tokens_set.contains(&token.content) {
self.added_tokens.push(token.clone());
}
}
/// Add a token to the added vocabulary
pub fn add_token(&mut self, token: &AddedToken) {
let mut added_tokens_map = self.added_tokens_map.lock().unwrap();
let mut added_tokens_map_r = self.added_tokens_map_r.lock().unwrap();

self.refresh_added_tokens(model, normalizer);
let id = added_tokens_map.len() as u32;
added_tokens_map.insert(token.content.clone(), id);
added_tokens_map_r.insert(id, token.clone());

// Return the number of added tokens
tokens.len() - ignored
self.refresh_added_tokens();
}
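For orientation, here is a simplified, self-contained sketch of the id-assignment scheme used by the new `add_token` above: the next id is simply the current size of the added-token map, and the forward and reverse maps are updated together. The interaction with the model's vocabulary size, which the removed `add_tokens` handled, is deliberately left out.

```rust
use std::collections::HashMap;

// Simplified model of the new `add_token` above: ids are dense and derived from
// the current number of added tokens; the model's vocabulary is ignored here.
fn add_token(map: &mut HashMap<String, u32>, map_r: &mut HashMap<u32, String>, content: &str) -> u32 {
    let id = map.len() as u32;
    map.insert(content.to_string(), id);
    map_r.insert(id, content.to_string());
    id
}

fn main() {
    let (mut map, mut map_r) = (HashMap::new(), HashMap::new());
    assert_eq!(add_token(&mut map, &mut map_r, "<tok_a>"), 0);
    assert_eq!(add_token(&mut map, &mut map_r, "<tok_b>"), 1);
    assert_eq!(map_r.get(&1).map(String::as_str), Some("<tok_b>"));
}
```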

/// Reconstruct our internal RegexSet when new tokens are added to the vocabulary.
@@ -348,9 +283,8 @@ impl AddedVocabulary {
fn refresh_added_tokens<N: Normalizer>(&mut self, model: &impl Model, normalizer: Option<&N>) {
type TupleTokenId<'a> = (&'a AddedToken, u32);
let (normalized, non_normalized): (Vec<TupleTokenId>, Vec<TupleTokenId>) = self
.special_tokens
.added_tokens
.iter()
.chain(self.added_tokens.iter())
.map(|token| {
(
token,
Expand Down Expand Up @@ -402,10 +336,9 @@ impl AddedVocabulary {
let mut stop = mat.end();
let aho_id = mat.pattern();
let id = split_re.1[aho_id];
let added_token = &self.added_tokens_map_r.get(&id).unwrap();
let added_token = self.added_tokens_map_r.lock().unwrap().get(&id).unwrap().clone();

if self.encode_special_tokens && self.special_tokens_set.contains(&added_token.content)
{
if self.encode_special_tokens && added_token.special {
continue;
}

@@ -543,6 +476,8 @@ impl Serialize for AddedVocabulary {
{
let mut added_tokens = self
.added_tokens_map_r
.lock()
.unwrap()
.iter()
.map(|(id, token)| AddedTokenWithId {
id: *id,