From 81d83361d0bfc466616d65f3eff91d723cc48630 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 5 Oct 2024 17:58:22 +0200 Subject: [PATCH] fix the unigram::from calls --- tokenizers/src/models/unigram/model.rs | 9 ++++++--- tokenizers/src/models/unigram/serialization.rs | 10 ++++++++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index 4a5371738..c604b11c6 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -548,7 +548,8 @@ mod tests { ("abcd".to_string(), 10.0), ]; - let model = Unigram::from(sentencepieces, Some(0), false).unwrap(); + let model = + Unigram::from(sentencepieces, Some(0), false, &AddedVocabulary::default()).unwrap(); let result = model.encode("abcd").unwrap(); assert_eq!(result, vec!["abcd"]); } @@ -570,7 +571,8 @@ mod tests { ("qr".to_string(), -0.5), ]; - let mut model = Unigram::from(sentencepieces, Some(0), false).unwrap(); + let mut model = + Unigram::from(sentencepieces, Some(0), false, &AddedVocabulary::default()).unwrap(); for is_optimized in &[true, false] { model.set_optimized(*is_optimized); @@ -617,7 +619,8 @@ mod tests { ("<0xC3>".to_string(), -0.01), ("<0xA9>".to_string(), -0.03), ]; - let unigram = Unigram::from(sentencepieces, Some(0), true).unwrap(); + let unigram = + Unigram::from(sentencepieces, Some(0), true, &AddedVocabulary::default()).unwrap(); let tokens: Vec<Token> = unigram.tokenize("é").unwrap(); assert_eq!( tokens, diff --git a/tokenizers/src/models/unigram/serialization.rs b/tokenizers/src/models/unigram/serialization.rs index 1ad95002e..f0ff30694 100644 --- a/tokenizers/src/models/unigram/serialization.rs +++ b/tokenizers/src/models/unigram/serialization.rs @@ -1,3 +1,5 @@ +use crate::AddedVocabulary; + use super::model::Unigram; use serde::{ de::{Error, MapAccess, Visitor}, @@ -69,8 +71,12 @@ impl<'de> Visitor<'de> for UnigramVisitor { } } match (vocab, unk_id, byte_fallback) 
{ - (Some(vocab), unk_id, byte_fallback) => Ok(Unigram::from(vocab, unk_id, byte_fallback) - .map_err(|err| Error::custom(format!("Unable to load vocab {err:?}")))?), + (Some(vocab), unk_id, byte_fallback) => { + Ok( + Unigram::from(vocab, unk_id, byte_fallback, &AddedVocabulary::default()) + .map_err(|err| Error::custom(format!("Unable to load vocab {err:?}")))?, + ) + } (None, _, _) => Err(Error::custom("Missing vocab")), } }