diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 91b8fe1bc..891609f23 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -534,6 +534,30 @@ impl PyBPE { )?, ) } + + /// Clears the internal cache + #[pyo3(signature = ())] + #[pyo3(text_signature = "(self)")] + fn _clear_cache(self_: PyRef) -> PyResult<()> { + let super_ = self_.as_ref(); + let mut model = super_.model.write().map_err(|e| { + exceptions::PyException::new_err(format!("Error while clearing BPE cache: {}", e)) + })?; + model.clear_cache(); + Ok(()) + } + + /// Resize the internal cache + #[pyo3(signature = (capacity))] + #[pyo3(text_signature = "(self, capacity)")] + fn _resize_cache(self_: PyRef, capacity: usize) -> PyResult<()> { + let super_ = self_.as_ref(); + let mut model = super_.model.write().map_err(|e| { + exceptions::PyException::new_err(format!("Error while resizing BPE cache: {}", e)) + })?; + model.resize_cache(capacity); + Ok(()) + } } /// An implementation of the WordPiece algorithm @@ -858,6 +882,30 @@ impl PyUnigram { )), } } + + /// Clears the internal cache + #[pyo3(signature = ())] + #[pyo3(text_signature = "(self)")] + fn _clear_cache(self_: PyRef) -> PyResult<()> { + let super_ = self_.as_ref(); + let mut model = super_.model.write().map_err(|e| { + exceptions::PyException::new_err(format!("Error while clearing Unigram cache: {}", e)) + })?; + model.clear_cache(); + Ok(()) + } + + /// Resize the internal cache + #[pyo3(signature = (capacity))] + #[pyo3(text_signature = "(self, capacity)")] + fn _resize_cache(self_: PyRef, capacity: usize) -> PyResult<()> { + let super_ = self_.as_ref(); + let mut model = super_.model.write().map_err(|e| { + exceptions::PyException::new_err(format!("Error while resizing Unigram cache: {}", e)) + })?; + model.resize_cache(capacity); + Ok(()) + } } /// Models Module diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 86fe74d50..df3841749 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -354,6 +354,13 @@ impl BPE { } } + /// Resize the cache + pub fn resize_cache(&mut self, capacity: usize) { + if let Some(ref mut cache) = self.cache { + cache.resize(capacity); + } + } + pub fn get_vocab(&self) -> Vocab { self.vocab.clone() } diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index a6021d90e..3ab3b495b 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -207,6 +207,23 @@ impl Model for ModelWrapper { } } +impl ModelWrapper { + pub fn clear_cache(&mut self) { + match self { + Self::Unigram(model) => model.clear_cache(), + Self::BPE(model) => model.clear_cache(), + _ => (), + } + } + pub fn resize_cache(&mut self, capacity: usize) { + match self { + Self::Unigram(model) => model.resize_cache(capacity), + Self::BPE(model) => model.resize_cache(capacity), + _ => (), + } + } +} + #[derive(Clone, Serialize, Deserialize)] pub enum TrainerWrapper { BpeTrainer(BpeTrainer), diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index dba5a0400..b80fdaf43 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -377,6 +377,16 @@ impl Unigram { let string = read_to_string(path)?; Ok(serde_json::from_str(&string)?) } + + /// Clears the internal cache + pub fn clear_cache(&mut self) { + self.cache.clear(); + } + + /// Resize the cache + pub fn resize_cache(&mut self, capacity: usize) { + self.cache.resize(capacity); + } } /// Iterator to iterate of vocabulary of the model, and their relative score. diff --git a/tokenizers/src/utils/cache.rs b/tokenizers/src/utils/cache.rs index dceb58da8..8407c3620 100644 --- a/tokenizers/src/utils/cache.rs +++ b/tokenizers/src/utils/cache.rs @@ -115,4 +115,11 @@ where pub(crate) fn set(&self, key: K, value: V) { self.set_values(std::iter::once((key, value))) } + + pub(crate) fn resize(&mut self, capacity: usize) { + self.capacity = capacity; + if let Ok(mut cache) = self.map.try_write() { + cache.shrink_to(capacity); + } + } }