From 5aa9f6cff07bab0f227c3a864bb47af9895700a1 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 7 Nov 2024 21:36:27 +0800 Subject: [PATCH] Disable caching for long strings. (#1676) --- tokenizers/src/models/bpe/model.rs | 6 ++++-- tokenizers/src/models/unigram/model.rs | 6 ++++-- tokenizers/src/utils/cache.rs | 3 +++ 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index df3841749..217c37e90 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -1,6 +1,6 @@ use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word}; use crate::tokenizer::{Model, Result, Token}; -use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY}; +use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY, MAX_LENGTH}; use crate::utils::iter::ResultShunt; use serde_json::Value; use std::borrow::Cow; @@ -482,7 +482,9 @@ impl BPE { let word = self.merge_word(sequence)?; let ret = self.word_to_tokens(&word).collect(); if let Some(ref cache) = self.cache { - cache.set(sequence.to_owned(), word); + if sequence.len() < MAX_LENGTH { + cache.set(sequence.to_owned(), word); + } } Ok(ret) } diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index b80fdaf43..da4d631ce 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -4,7 +4,7 @@ use super::{ trie::{Trie, TrieBuilder}, }; use crate::tokenizer::{Model, Result, Token}; -use crate::utils::cache::Cache; +use crate::utils::cache::{Cache, MAX_LENGTH}; use std::collections::HashMap; use std::convert::TryInto; @@ -230,7 +230,9 @@ impl Unigram { } else { self.encode_unoptimized(sentence)? 
}; - self.cache.set(sentence.to_owned(), result.clone()); + if sentence.len() < MAX_LENGTH { + self.cache.set(sentence.to_owned(), result.clone()); + } Ok(result) } } diff --git a/tokenizers/src/utils/cache.rs b/tokenizers/src/utils/cache.rs index 8407c3620..002fb1d61 100644 --- a/tokenizers/src/utils/cache.rs +++ b/tokenizers/src/utils/cache.rs @@ -5,6 +5,9 @@ use std::sync::RwLock; /// The default capacity for a `BPE`'s internal cache. pub static DEFAULT_CACHE_CAPACITY: usize = 10_000; +/// Length threshold for caching in a model: only inputs whose `len()` is strictly below this value are cached. +/// Longer strings have minimal chances of producing a cache hit anyway. +pub static MAX_LENGTH: usize = 256; /// Provides a simple multithread cache to speed up BPE tokenization that will try to read values /// concurrently but won't block if another thread is writing.