From 39a29e2d3138bddeb569a9748c29cc2c9a3a19a7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 6 Nov 2024 15:24:20 +0800 Subject: [PATCH] Disable caching for long strings. --- tokenizers/src/models/bpe/model.rs | 6 ++++-- tokenizers/src/models/unigram/model.rs | 6 ++++-- tokenizers/src/utils/cache.rs | 3 +++ 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 86fe74d50..9c57819ee 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -1,6 +1,6 @@ use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word}; use crate::tokenizer::{Model, Result, Token}; -use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY}; +use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY, MAX_LENGTH}; use crate::utils::iter::ResultShunt; use serde_json::Value; use std::borrow::Cow; @@ -475,7 +475,9 @@ impl BPE { let word = self.merge_word(sequence)?; let ret = self.word_to_tokens(&word).collect(); if let Some(ref cache) = self.cache { - cache.set(sequence.to_owned(), word); + if sequence.len() < MAX_LENGTH { + cache.set(sequence.to_owned(), word); + } } Ok(ret) } diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index dba5a0400..b744b523d 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -4,7 +4,7 @@ use super::{ trie::{Trie, TrieBuilder}, }; use crate::tokenizer::{Model, Result, Token}; -use crate::utils::cache::Cache; +use crate::utils::cache::{Cache, MAX_LENGTH}; use std::collections::HashMap; use std::convert::TryInto; @@ -230,7 +230,9 @@ impl Unigram { } else { self.encode_unoptimized(sentence)? }; - self.cache.set(sentence.to_owned(), result.clone()); + if sentence.len() < MAX_LENGTH { + self.cache.set(sentence.to_owned(), result.clone()); + } Ok(result) } } diff --git a/tokenizers/src/utils/cache.rs b/tokenizers/src/utils/cache.rs index dceb58da8..bba808d6e 100644 --- a/tokenizers/src/utils/cache.rs +++ b/tokenizers/src/utils/cache.rs @@ -5,6 +5,9 @@ use std::sync::RwLock; /// The default capacity for a `BPE`'s internal cache. pub static DEFAULT_CACHE_CAPACITY: usize = 10_000; +/// The maximum length we should cache in a model +/// Strings that are too long have minimal chances to cache hit anyway +pub static MAX_LENGTH: usize = 256; /// Provides a simple multithread cache to speed up BPE tokenization that will try to read values /// concurrently but won't block if another thread is writing.