Skip to content

Commit

Permalink
More cache options. (#1675)
Browse files Browse the repository at this point in the history
* More cache options.

* Fixing error messages.
  • Loading branch information
Narsil authored Nov 6, 2024
1 parent 1740bff commit c6b5c3e
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 0 deletions.
48 changes: 48 additions & 0 deletions bindings/python/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,30 @@ impl PyBPE {
)?,
)
}

/// Clears the internal cache
#[pyo3(signature = ())]
#[pyo3(text_signature = "(self)")]
fn _clear_cache(self_: PyRef<Self>) -> PyResult<()> {
let super_ = self_.as_ref();
let mut model = super_.model.write().map_err(|e| {
exceptions::PyException::new_err(format!("Error while clearing BPE cache: {}", e))
})?;
model.clear_cache();
Ok(())
}

/// Resize the internal cache
#[pyo3(signature = (capacity))]
#[pyo3(text_signature = "(self, capacity)")]
fn _resize_cache(self_: PyRef<Self>, capacity: usize) -> PyResult<()> {
let super_ = self_.as_ref();
let mut model = super_.model.write().map_err(|e| {
exceptions::PyException::new_err(format!("Error while resizing BPE cache: {}", e))
})?;
model.resize_cache(capacity);
Ok(())
}
}

/// An implementation of the WordPiece algorithm
Expand Down Expand Up @@ -858,6 +882,30 @@ impl PyUnigram {
)),
}
}

/// Clears the internal cache
#[pyo3(signature = ())]
#[pyo3(text_signature = "(self)")]
fn _clear_cache(self_: PyRef<Self>) -> PyResult<()> {
let super_ = self_.as_ref();
let mut model = super_.model.write().map_err(|e| {
exceptions::PyException::new_err(format!("Error while clearing Unigram cache: {}", e))
})?;
model.clear_cache();
Ok(())
}

/// Resize the internal cache
#[pyo3(signature = (capacity))]
#[pyo3(text_signature = "(self, capacity)")]
fn _resize_cache(self_: PyRef<Self>, capacity: usize) -> PyResult<()> {
let super_ = self_.as_ref();
let mut model = super_.model.write().map_err(|e| {
exceptions::PyException::new_err(format!("Error while resizing Unigram cache: {}", e))
})?;
model.resize_cache(capacity);
Ok(())
}
}

/// Models Module
Expand Down
7 changes: 7 additions & 0 deletions tokenizers/src/models/bpe/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,13 @@ impl BPE {
}
}

/// Resize the cache
pub fn resize_cache(&mut self, capacity: usize) {
if let Some(ref mut cache) = self.cache {
cache.resize(capacity);
}
}

pub fn get_vocab(&self) -> Vocab {
self.vocab.clone()
}
Expand Down
17 changes: 17 additions & 0 deletions tokenizers/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,23 @@ impl Model for ModelWrapper {
}
}

impl ModelWrapper {
pub fn clear_cache(&mut self) {
match self {
Self::Unigram(model) => model.clear_cache(),
Self::BPE(model) => model.clear_cache(),
_ => (),
}
}
pub fn resize_cache(&mut self, capacity: usize) {
match self {
Self::Unigram(model) => model.resize_cache(capacity),
Self::BPE(model) => model.resize_cache(capacity),
_ => (),
}
}
}

#[derive(Clone, Serialize, Deserialize)]
pub enum TrainerWrapper {
BpeTrainer(BpeTrainer),
Expand Down
10 changes: 10 additions & 0 deletions tokenizers/src/models/unigram/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,16 @@ impl Unigram {
let string = read_to_string(path)?;
Ok(serde_json::from_str(&string)?)
}

/// Clears the internal cache
pub fn clear_cache(&mut self) {
self.cache.clear();
}

/// Resize the cache
pub fn resize_cache(&mut self, capacity: usize) {
self.cache.resize(capacity);
}
}

/// Iterator to iterate of vocabulary of the model, and their relative score.
Expand Down
7 changes: 7 additions & 0 deletions tokenizers/src/utils/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,11 @@ where
pub(crate) fn set(&self, key: K, value: V) {
self.set_values(std::iter::once((key, value)))
}

pub(crate) fn resize(&mut self, capacity: usize) {
self.capacity = capacity;
if let Ok(mut cache) = self.map.try_write() {
cache.shrink_to(capacity);
}
}
}

0 comments on commit c6b5c3e

Please sign in to comment.