From b5640a65cf59cf6c4ac2458dd01fc695cb0c7504 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Fri, 4 Oct 2024 14:46:42 +0200
Subject: [PATCH] simplify the logic

---
 bindings/python/src/tokenizer.rs             | 24 +++++---------------
 tokenizers/src/tokenizer/added_vocabulary.rs | 15 ++++++------
 tokenizers/src/tokenizer/mod.rs              |  5 ++--
 3 files changed, 16 insertions(+), 28 deletions(-)

diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 9b0b82dcf..499cbd770 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -1250,21 +1250,11 @@ impl PyTokenizer {
     /// Returns:
     ///     :obj:`int`: The number of tokens that were created in the vocabulary
     #[pyo3(text_signature = "(self, old_tokens, new_tokens)")]
-    fn assign_tokens(
-        &mut self,
-        old_tokens: &Bound<'_, PyList>,
-        new_tokens: &Bound<'_, PyList>,
-    ) -> PyResult<()> {
+    fn assign_tokens(&mut self, old_to_new_map: &Bound<'_, PyDict>) -> PyResult<()> {
         use pyo3::exceptions::PyTypeError;
-        if old_tokens.len() != new_tokens.len() {
-            return Err(PyTypeError::new_err(
-                "old_tokens and new_tokens must have the same length",
-            ));
-        }
-        let mut processed_old_tokens = Vec::with_capacity(old_tokens.len());
-        let mut processed_new_tokens = Vec::with_capacity(new_tokens.len());
-        for (old, new) in old_tokens.iter().zip(new_tokens.iter()) {
+        let mut processed_old_tokens = HashMap::with_capacity(old_to_new_map.len());
+        for (old, new) in old_to_new_map.iter() {
             let old_token = if let Ok(content) = old.extract::<&str>() {
                 PyAddedToken::from(content.to_string(), Some(false)).get_token()
             } else if let Ok(token) = old.extract::<PyRefMut<PyAddedToken>>() {
@@ -1287,12 +1277,10 @@ impl PyTokenizer {
                 ));
             };
-            processed_old_tokens.push(old_token);
-            processed_new_tokens.push(new_token);
+            processed_old_tokens.insert(old_token, new_token);
         }
-        Ok(self
-            .tokenizer
-            .assign_tokens(&processed_old_tokens, &processed_new_tokens))
+        self.tokenizer.assign_tokens(&processed_old_tokens);
+        Ok(())
     }
 
     /// Add the given special tokens to the Tokenizer.
     ///
diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index d3ca1a484..6f79ba660 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -311,25 +311,26 @@ impl AddedVocabulary {
     /// use reserved tokens (which usually are in the original vocab, and in the added vocab)
     pub fn assign_tokens<N: Normalizer>(
         &mut self,
-        old_token_content: &[AddedToken],
-        new_token_content: &[AddedToken],
+        token_map: &HashMap<AddedToken, AddedToken>, // HashMap of old token to new token
         model: &impl Model,
         normalizer: Option<&N>,
     ) {
-        for (old, new) in old_token_content.iter().zip(new_token_content.iter()) {
-            if let Some(id) = self.token_to_id(old.content.as_str(), model) {
+        for (old_token, new_token) in token_map.iter() {
+            if let Some(id) = self.token_to_id(old_token.content.as_str(), model) {
                 self.added_tokens_map_r
                     .lock()
                     .unwrap()
                     .entry(id)
-                    .and_modify(|t| t.content = new.content.clone());
+                    .and_modify(|t| *t = new_token.clone()); // Replace entire entry with new_token
                 self.refresh_added_tokens(model, normalizer);
             } else {
-                error!("Error: you tried to re-assign a token that does not exist in the added vocab. Make sure {:?} is first added to the vocab", old.content.clone())
+                error!(
+                    "Error: you tried to re-assign a token that does not exist in the added vocab. Make sure {:?} is first added to the vocab",
+                    old_token.content.clone()
+                )
             }
         }
     }
-
     /// Reconstruct our internal RegexSet when new tokens are added to the vocabulary.
     ///
     /// We keep two different RegexSet, one that will take care of matching against the
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index c6433dc43..c24654fc8 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -960,10 +960,9 @@ where
     }
 
     /// Assign a new token
-    pub fn assign_tokens(&mut self, old_tokens: &[AddedToken], new_tokens: &[AddedToken]) {
+    pub fn assign_tokens(&mut self, old_to_new_map: &HashMap<AddedToken, AddedToken>) {
         self.added_vocabulary.assign_tokens(
-            old_tokens,
-            new_tokens,
+            old_to_new_map, // HashMap of old token to new token
             &self.model,
             self.normalizer.as_ref(),
         )
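
Usage sketch (not part of the patch): a minimal example of calling the reworked Rust API, assuming the patched `Tokenizer::assign_tokens(&HashMap<AddedToken, AddedToken>)` above, a local `tokenizer.json`, and placeholder token contents; it also assumes `AddedToken` can serve as a `HashMap` key, as the new signature requires. On the Python side the binding now takes a single dict mapping old tokens (str or AddedToken) to their replacements instead of two parallel lists.

// Sketch only: relies on the patched `assign_tokens` signature shown in the diff.
use std::collections::HashMap;

use tokenizers::{AddedToken, Tokenizer};

fn main() -> tokenizers::Result<()> {
    // Placeholder file name; any serialized tokenizer with added tokens works.
    let mut tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // Map each existing added token to the token that should replace it.
    let mut old_to_new: HashMap<AddedToken, AddedToken> = HashMap::new();
    old_to_new.insert(
        AddedToken::from("<extra_id_0>", true),
        AddedToken::from("<think>", true),
    );

    // Entries found in the added vocabulary are replaced in place; missing
    // ones are logged with `error!` and skipped, per the patched loop above.
    tokenizer.assign_tokens(&old_to_new);
    Ok(())
}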