diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index dcaf95bea..715f9a264 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -537,8 +537,8 @@ impl PyBPE { fn clear_cache(self_: PyRef){ let super_ = self_.as_ref(); - let model = super_.model.read().unwrap(); - if let ModelWrapper::BPE(ref mo) = *model { + let mut model = super_.model.write().unwrap(); + if let ModelWrapper::Unigram(ref mut mo) = *model { mo.clear_cache() } else { unreachable!() @@ -868,6 +868,16 @@ impl PyUnigram { )), } } + + fn clear_cache(self_: PyRef){ + let super_ = self_.as_ref(); + let mut model = super_.model.write().unwrap(); + if let ModelWrapper::Unigram(ref mut mo) = *model { + mo.clear_cache() + } else { + unreachable!() + }; + } } /// Models Module diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index dba5a0400..8bafe81f4 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -377,6 +377,10 @@ impl Unigram { let string = read_to_string(path)?; Ok(serde_json::from_str(&string)?) } + + pub fn clear_cache(&mut self){ + self.cache.clear(); + } } /// Iterator to iterate of vocabulary of the model, and their relative score.