Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurZucker committed Jul 12, 2024
1 parent b9481d4 commit 3d2705f
Showing 1 changed file with 67 additions and 2 deletions.
69 changes: 67 additions & 2 deletions tokenizers/src/tokenizer/added_vocabulary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,71 @@ impl AddedVocabulary {
) -> usize {
self.add_tokens(tokens, model, normalizer)
}

/// Add some tokens to the vocabulary
pub fn add_tokens<N: Normalizer>(
&mut self,
tokens: &[AddedToken],
model: &impl Model,
normalizer: Option<&N>,
) -> usize {
// Handle special tokens (if any)

// Then we delegate to `add_tokens`, that will take care of refreshing added tokens too.
let mut ignored = 0;
for token in tokens {
if token.content.is_empty()
|| self
.added_tokens_map_r
.lock()
.unwrap()
.values()
.any(|val| val == token)
{
ignored += 1;
continue;
}
// If a token is already part of the vocabulary, we mark it as added
let new_id = if let Some(new_id) = self.token_to_id(&token.content, model) {
new_id
} else {
self.added_tokens_map
.lock()
.unwrap()
.values()
.cloned()
.max()
.map_or(model.get_vocab_size() as u32, |max| {
if (max >= model.get_vocab_size() as u32) || model.get_vocab_size() == 0 {
max + 1
} else {
model.get_vocab_size() as u32
}
})
};
// Make sure we modify the previous entry
self.added_tokens_map
.lock()
.unwrap()
.entry(token.content.clone())
.and_modify(|old_id| *old_id = new_id)
.or_insert_with(|| new_id);
// Update the current revert operation
self.added_tokens_map_r
.lock()
.unwrap()
.entry(new_id)
.and_modify(|t| *t = token.clone())
.or_insert_with(|| token.clone());
// Make sure to remove previous entry (if the token gets a new id)
}

self.refresh_added_tokens(model, normalizer);

// Return the number of added tokens
tokens.len() - ignored
}

/// Get the token matching the given id if it exists
pub fn simple_id_to_token(&self, id: &u32) -> Option<String> {
let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap();
Expand Down Expand Up @@ -325,8 +390,8 @@ impl AddedVocabulary {
let mut stop = mat.end();
let aho_id = mat.pattern();
let id = split_re.1[aho_id];
let added_token = self.added_tokens_map_r.lock().unwrap().get(&id).unwrap();

let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap();
let added_token = added_tokens_map_r.get(&id).unwrap();
if self.encode_special_tokens && added_token.special {
continue;
}
Expand Down

0 comments on commit 3d2705f

Please sign in to comment.