diff --git a/benches/tokenize.rs b/benches/tokenize.rs index 5148b94..e72c1dd 100644 --- a/benches/tokenize.rs +++ b/benches/tokenize.rs @@ -128,7 +128,7 @@ fn tokenize_with_hf(tokenizer: &HFTokenizer, seq: &str) -> Vec { } fn tokenize_with_sam>( - tokenizer: &GreedyTokenizer, + tokenizer: &GreedyTokenizer>, seq: &str, ) -> Vec { tokenizer diff --git a/src/sam/mod.rs b/src/sam/mod.rs index 9827ea0..38cb967 100644 --- a/src/sam/mod.rs +++ b/src/sam/mod.rs @@ -114,11 +114,14 @@ impl GeneralSAM { self.node_pool.get(node_id) } - pub fn get_root_state(&self) -> GeneralSAMState { + pub fn get_root_state(&self) -> GeneralSAMState> { self.get_state(SAM_ROOT_NODE_ID) } - pub fn get_state(&self, node_id: GeneralSAMNodeID) -> GeneralSAMState { + pub fn get_state( + &self, + node_id: GeneralSAMNodeID, + ) -> GeneralSAMState> { if node_id < self.node_pool.len() { GeneralSAMState { sam: self, node_id } } else { diff --git a/src/sam/state.rs b/src/sam/state.rs index 750e528..1a75cdd 100644 --- a/src/sam/state.rs +++ b/src/sam/state.rs @@ -1,37 +1,59 @@ //! States of a general suffix automaton. +use std::ops::Deref; + use crate::{TravelEvent, TrieNodeAlike}; use super::{GeneralSAM, GeneralSAMNode, TransitionTable, SAM_NIL_NODE_ID, SAM_ROOT_NODE_ID}; #[derive(Debug)] -pub struct GeneralSAMState<'s, TransTable: TransitionTable> { - pub sam: &'s GeneralSAM, +pub struct GeneralSAMState< + TransTable: TransitionTable, + SAMRef: Deref>, +> { + pub sam: SAMRef, pub node_id: usize, } -impl<'s, TransTable: TransitionTable> Clone for GeneralSAMState<'s, TransTable> { +impl> + Clone> Clone + for GeneralSAMState +{ fn clone(&self) -> Self { Self { - sam: self.sam, + sam: self.sam.clone(), node_id: self.node_id, } } } -impl<'s, TransTable: TransitionTable> GeneralSAMState<'s, TransTable> { - pub fn feed_bytes(self, seq: &'s str) -> Self { +impl, SAMRef: Deref>> + GeneralSAMState +{ + pub fn feed_bytes(self, seq: &str) -> Self { self.feed_ref(seq.as_bytes()) } } -impl<'s, TransTable: TransitionTable> GeneralSAMState<'s, TransTable> { +impl< + TransTable: TransitionTable, + SAMRef: Deref>, + > GeneralSAMState +{ pub fn feed_chars(self, seq: &str) -> Self { self.feed(seq.chars()) } } -impl<'s, TransTable: TransitionTable> GeneralSAMState<'s, TransTable> { +impl>> + GeneralSAMState +{ + pub fn inner_as_ref(&self) -> GeneralSAMState> { + GeneralSAMState { + sam: &self.sam, + node_id: self.node_id, + } + } + pub fn is_nil(&self) -> bool { self.node_id == SAM_NIL_NODE_ID } @@ -46,10 +68,8 @@ impl<'s, TransTable: TransitionTable> GeneralSAMState<'s, TransTable> { .unwrap_or(false) } - pub fn get_non_nil_trans(&self, key: &TransTable::KeyType) -> Option { - self.get_node() - .and_then(|node| node.trans.get(key)) - .map(|x| self.sam.get_state(*x)) + pub fn get_sam_ref(&self) -> &GeneralSAM { + &self.sam } pub fn get_node(&self) -> Option<&GeneralSAMNode> { @@ -86,17 +106,21 @@ impl<'s, TransTable: TransitionTable> GeneralSAMState<'s, TransTable> { } self } -} -impl<'s, TransTable: TransitionTable> GeneralSAMState<'s, TransTable> { - pub fn feed_ref>(self, seq: Seq) -> Self { + pub fn feed_ref<'s, Seq: IntoIterator>(self, seq: Seq) -> Self + where + ::KeyType: 's, + { self.feed_ref_iter(seq.into_iter()) } - pub fn feed_ref_iter>( + pub fn feed_ref_iter<'s, Iter: Iterator>( mut self, iter: Iter, - ) -> Self { + ) -> Self + where + ::KeyType: 's, + { for t in iter { if self.is_nil() { break; @@ -105,21 +129,36 @@ impl<'s, TransTable: TransitionTable> GeneralSAMState<'s, TransTable> { } self } +} + +impl> + Clone> + GeneralSAMState +{ + pub fn get_non_nil_trans(&self, key: &TransTable::KeyType) -> Option { + self.get_node() + .and_then(|node| node.trans.get(key)) + .map(|x| Self { + sam: self.sam.clone(), + node_id: *x, + }) + } fn wrap_travel_along_callback< + 's, TN: TrieNodeAlike, ExtraType, ErrorType, F: 's + FnMut( - TravelEvent<(&GeneralSAMState, &TN), ExtraType, TN::InnerType>, + TravelEvent<(&Self, &TN), ExtraType, TN::InnerType>, ) -> Result, >( &'s self, mut callback: F, ) -> impl FnMut( - TravelEvent<&TN, (GeneralSAMState<'s, TransTable>, ExtraType), TN::InnerType>, - ) -> Result<(GeneralSAMState<'s, TransTable>, ExtraType), ErrorType> { + TravelEvent<&TN, (Self, ExtraType), TN::InnerType>, + ) -> Result<(Self, ExtraType), ErrorType> + + 's { move |event| match event { TravelEvent::PushRoot(trie_root) => { let res = callback(TravelEvent::PushRoot((self, trie_root)))?; @@ -143,9 +182,7 @@ impl<'s, TransTable: TransitionTable> GeneralSAMState<'s, TransTable> { TN: TrieNodeAlike + Clone, ExtraType, ErrorType, - F: FnMut( - TravelEvent<(&GeneralSAMState, &TN), ErrorType, TN::InnerType>, - ) -> Result, + F: FnMut(TravelEvent<(&Self, &TN), ErrorType, TN::InnerType>) -> Result, >( &self, trie_node: TN, @@ -158,9 +195,7 @@ impl<'s, TransTable: TransitionTable> GeneralSAMState<'s, TransTable> { TN: TrieNodeAlike, ExtraType, ErrorType, - F: FnMut( - TravelEvent<(&GeneralSAMState, &TN), ErrorType, TN::InnerType>, - ) -> Result, + F: FnMut(TravelEvent<(&Self, &TN), ErrorType, TN::InnerType>) -> Result, >( &self, trie_node: TN, diff --git a/src/tests/utils.rs b/src/tests/utils.rs index d69eed5..535a37e 100644 --- a/src/tests/utils.rs +++ b/src/tests/utils.rs @@ -72,7 +72,7 @@ fn test_rope() { #[cfg(feature = "trie")] mod trie { - use std::collections::BTreeMap; + use std::{collections::BTreeMap, ops::Deref}; use rand::{ distributions::{Alphanumeric, DistString}, @@ -129,8 +129,9 @@ mod trie { T: Clone, TransTable: TransitionTable, Iter: Iterator, + SAMRef: Deref>, >( - tokenizer: &GreedyTokenizer, + tokenizer: &GreedyTokenizer, trie: &Trie, seq: Iter, ) { @@ -188,6 +189,32 @@ mod trie { case_tokenizer(&tokenizer, &trie, "abc".bytes()); } + #[test] + fn test_tokenizer_owning_sam() { + let vocab = [ + "a", "ab", "b", "bc", "c", "d", "e", "f", "cd", "abcde", "你好", "🧡", + ]; + let mut trie = Trie::>::default(); + let mut id_to_word = BTreeMap::new(); + for word in vocab { + id_to_word.insert(trie.insert_iter(word.bytes()), word); + } + + let sam = GeneralSAM::>::from_trie(trie.get_root_state()); + + let tokenizer = GreedyTokenizer::, _, _>::build_from_sam_and_trie( + sam, + trie.get_root_state(), + ); + + case_tokenizer(&tokenizer, &trie, "abcde".bytes()); + case_tokenizer(&tokenizer, &trie, "abcdf".bytes()); + case_tokenizer(&tokenizer, &trie, "abca".bytes()); + case_tokenizer(&tokenizer, &trie, "Hi,你好吗?".bytes()); + case_tokenizer(&tokenizer, &trie, "🧡🧡🧡🧡🧡!".bytes()); + case_tokenizer(&tokenizer, &trie, "abc".bytes()); + } + fn case_tokenizer_vocab< T: Clone + Ord + Eq + std::hash::Hash, TransTable: TransitionTable, diff --git a/src/trie.rs b/src/trie.rs index cb3b296..25c254d 100644 --- a/src/trie.rs +++ b/src/trie.rs @@ -1,5 +1,7 @@ //! Trie, supporting `TrieNodeAlike`. +use std::ops::Deref; + use crate::{ConstructiveTransitionTable, GeneralSAMNodeID, TransitionTable, TrieNodeAlike}; pub type TrieNodeID = GeneralSAMNodeID; @@ -18,12 +20,23 @@ pub struct Trie { node_pool: Vec>, } -#[derive(Clone, Debug)] -pub struct TrieState<'s, TransTable: TransitionTable> { - pub trie: &'s Trie, +#[derive(Debug)] +pub struct TrieState>> { + pub trie: TrieRef, pub node_id: TrieNodeID, } +impl> + Clone> Clone + for TrieState +{ + fn clone(&self) -> Self { + Self { + trie: self.trie.clone(), + node_id: self.node_id, + } + } +} + impl TrieNode { fn new(parent: TrieNodeID) -> Self { Self { @@ -70,7 +83,7 @@ impl Trie { self.node_pool.len() } - pub fn get_state(&self, node_id: TrieNodeID) -> TrieState { + pub fn get_state(&self, node_id: TrieNodeID) -> TrieState> { if node_id >= self.node_pool.len() { return TrieState { trie: self, @@ -91,7 +104,7 @@ impl Trie { self.get_node(TRIE_ROOT_NODE_ID).unwrap() } - pub fn get_root_state(&self) -> TrieState { + pub fn get_root_state(&self) -> TrieState> { self.get_state(TRIE_ROOT_NODE_ID) } @@ -142,7 +155,16 @@ impl Trie { } } -impl<'s, TransTable: TransitionTable> TrieState<'s, TransTable> { +impl>> + TrieState +{ + pub fn inner_as_ref(&self) -> TrieState> { + TrieState { + trie: &self.trie, + node_id: self.node_id, + } + } + pub fn is_nil(&self) -> bool { self.node_id == TRIE_NIL_NODE_ID } @@ -151,7 +173,7 @@ impl<'s, TransTable: TransitionTable> TrieState<'s, TransTable> { self.node_id == TRIE_ROOT_NODE_ID } - pub fn get_node(&self) -> Option<&'s TrieNode> { + pub fn get_node(&self) -> Option<&TrieNode> { self.trie.get_node(self.node_id) } @@ -170,15 +192,28 @@ impl<'s, TransTable: TransitionTable> TrieState<'s, TransTable> { self.node_id = TRIE_NIL_NODE_ID; } } + + pub fn feed_iter>(&mut self, iter: Iter) { + iter.for_each(|x| self.goto(&x)); + } + + pub fn feed_ref_iter<'s, Iter: Iterator>( + &'s mut self, + iter: Iter, + ) { + iter.for_each(|x| self.goto(x)); + } } #[derive(Clone, Debug)] pub struct NextTrieStateIter<'s, TransTable: TransitionTable> { - state: TrieState<'s, TransTable>, + trie: &'s Trie, iter: TransTable::IterType<'s>, } -impl<'s, TransTable: TransitionTable> TrieNodeAlike for TrieState<'s, TransTable> { +impl<'s, TransTable: TransitionTable> TrieNodeAlike + for TrieState> +{ type InnerType = TransTable::KeyType; type NextStateIter = NextTrieStateIter<'s, TransTable>; @@ -187,23 +222,23 @@ impl<'s, TransTable: TransitionTable> TrieNodeAlike for TrieState<'s, TransTable } fn next_states(self) -> Self::NextStateIter { - let iter = self.get_node().unwrap().trans.iter(); - NextTrieStateIter { state: self, iter } + let iter = self.trie.get_node(self.node_id).unwrap().trans.iter(); + NextTrieStateIter { + trie: self.trie, + iter, + } } } impl<'s, TransTable: TransitionTable> Iterator for NextTrieStateIter<'s, TransTable> { - type Item = (TransTable::KeyType, TrieState<'s, TransTable>); + type Item = ( + TransTable::KeyType, + TrieState>, + ); fn next(&mut self) -> Option { self.iter .next() - .map(|(t, next_node_id)| (t.clone(), self.state.trie.get_state(*next_node_id))) - } -} - -impl<'s, TransTable: TransitionTable> NextTrieStateIter<'s, TransTable> { - pub fn get_state(&self) -> &TrieState { - &self.state + .map(|(t, next_node_id)| (t.clone(), self.trie.get_state(*next_node_id))) } } diff --git a/src/trie_alike.rs b/src/trie_alike.rs index 2ab1a7b..a4cc595 100644 --- a/src/trie_alike.rs +++ b/src/trie_alike.rs @@ -3,6 +3,7 @@ use std::collections::VecDeque; +#[derive(Clone, Debug)] pub enum TravelEvent<'s, NodeType, ExtraType, KeyType> { PushRoot(NodeType), Push(NodeType, &'s ExtraType, KeyType), diff --git a/src/utils/suffixwise.rs b/src/utils/suffixwise.rs index 5a2891e..93bb3da 100644 --- a/src/utils/suffixwise.rs +++ b/src/utils/suffixwise.rs @@ -137,7 +137,7 @@ impl SuffixInTrieData { ) -> Vec { let mut sam_to_data = vec![LinkedList::>::new(); sam.num_of_nodes()]; let callback = - |event: TravelEvent<(&GeneralSAMState<_>, &TN), _, _>| -> Result<_, Infallible> { + |event: TravelEvent<(&GeneralSAMState<_, &GeneralSAM<_>>, &TN), _, _>| -> Result<_, Infallible> { match event { crate::TravelEvent::Pop((sam_state, trie_state), len) => { if trie_state.is_accepting() { diff --git a/src/utils/tokenize.rs b/src/utils/tokenize.rs index e3439b1..13e5b43 100644 --- a/src/utils/tokenize.rs +++ b/src/utils/tokenize.rs @@ -1,6 +1,6 @@ //! Greedy tokenizer. -use std::ops::{AddAssign, SubAssign}; +use std::ops::{AddAssign, Deref, SubAssign}; use crate::{GeneralSAM, GeneralSAMState, TransitionTable, TrieNodeAlike}; @@ -20,28 +20,74 @@ use super::suffixwise::SuffixInTrieData; /// will be further merged in the ropes of its successors. #[derive(Clone, Debug)] pub struct GreedyTokenizer< - 's, TransTable: TransitionTable, TokenIDType: Clone + Default + PartialEq, + SAMRef: Deref>, > { - sam: &'s GeneralSAM, + sam: SAMRef, suffix_data: Vec>, } -impl<'s, TransTable: TransitionTable, TokenIDType: Clone + Default + PartialEq> - GreedyTokenizer<'s, TransTable, TokenIDType> +pub struct OwnedGeneralSAM { + pub sam: GeneralSAM, +} + +impl Deref for OwnedGeneralSAM { + type Target = GeneralSAM; + + fn deref(&self) -> &Self::Target { + &self.sam + } +} + +impl + GreedyTokenizer> +{ + pub fn build_from_sam< + TN: TrieNodeAlike, + F: FnMut(&TN) -> TokenIDType, + >( + sam: GeneralSAM, + trie_node: TN, + f: F, + ) -> Self { + Self { + suffix_data: SuffixInTrieData::build(&sam, trie_node, f), + sam: OwnedGeneralSAM { sam }, + } + } +} + +impl< + TransTable: TransitionTable, + TokenIDType: Clone + Default + PartialEq, + SAMRef: Deref>, + > GreedyTokenizer { + pub fn get_sam_ref(&self) -> &GeneralSAM { + &self.sam + } + + pub fn inner_as_ref( + &self, + ) -> GreedyTokenizer> { + GreedyTokenizer { + sam: &self.sam, + suffix_data: self.suffix_data.clone(), + } + } + pub fn build< TN: TrieNodeAlike, F: FnMut(&TN) -> TokenIDType, >( - sam: &'s GeneralSAM, + sam: SAMRef, trie_node: TN, f: F, ) -> Self { Self { + suffix_data: SuffixInTrieData::build(sam.deref(), trie_node, f), sam, - suffix_data: SuffixInTrieData::build(sam, trie_node, f), } } @@ -62,22 +108,24 @@ impl<'s, TransTable: TransitionTable, TokenIDType: Clone + Default + PartialEq> res.push((token_id, token_len)) }; - let pop_buffer = - |cur_len: &mut usize, cur_state: &mut GeneralSAMState, res: &mut Vec<_>| { - let inner_data = self.suffix_data[cur_state.node_id] - .get(*cur_len) - .expect("invalid state"); - - // TODO: optimize for unknown token: - // find the lower bound position where the suffix is prefixed with a token - let (token_id, token_len) = inner_data.as_ref().map_or_else( - || (unk_token_id, 1), - |token_info| (&token_info.digested_trie_node, token_info.seq_len), - ); - - cur_len.sub_assign(token_len); - push(res, token_id.clone(), token_len); - }; + let pop_buffer = |cur_len: &mut usize, + cur_state: &mut GeneralSAMState>, + res: &mut Vec<_>| { + let inner_data = self.suffix_data[cur_state.node_id] + .get(*cur_len) + .expect("invalid state"); + + // TODO: Optimize for unknown tokens: + // Find the lower bound position where the suffix is prefixed with a token. + // But this does not improve the time complexity, pending... + let (token_id, token_len) = inner_data.as_ref().map_or_else( + || (unk_token_id, 1), + |token_info| (&token_info.digested_trie_node, token_info.seq_len), + ); + + cur_len.sub_assign(token_len); + push(res, token_id.clone(), token_len); + }; let mut cur_state = self.sam.get_root_state(); let mut cur_len = 0; @@ -119,17 +167,34 @@ impl<'s, TransTable: TransitionTable, TokenIDType: Clone + Default + PartialEq> #[cfg(feature = "trie")] #[cfg_attr(doc_cfg, doc(cfg(feature = "trie")))] pub mod trie { + use std::ops::Deref; + use crate::{GeneralSAM, TransitionTable, Trie, TrieNodeAlike, TrieNodeID, TrieState}; - impl<'s, TransTable: TransitionTable> super::GreedyTokenizer<'s, TransTable, TrieNodeID> { - pub fn build_from_trie<'t, TT: TransitionTable>( - sam: &'s GeneralSAM, - trie_state: TrieState<'t, TT>, + use super::OwnedGeneralSAM; + + impl>> + super::GreedyTokenizer + { + pub fn build_from_trie>( + sam: SAMRef, + trie_state: TrieState>, ) -> Self { Self::build(sam, trie_state, |tn| tn.node_id) } } + impl + super::GreedyTokenizer> + { + pub fn build_from_sam_and_trie>( + sam: GeneralSAM, + trie_state: TrieState>, + ) -> Self { + Self::build_from_sam(sam, trie_state, |tn| tn.node_id) + } + } + /// Greedy tokenizer with a trie of the vocabulary. /// /// Assuming that the input length is $n$, the maximum word length is $l$, diff --git a/src/utils/treap.rs b/src/utils/treap.rs index 8c7bf3e..d50aff4 100644 --- a/src/utils/treap.rs +++ b/src/utils/treap.rs @@ -6,6 +6,7 @@ use rand::random; pub type NeedSwap = bool; +#[derive(Clone, Debug)] pub enum SplitTo { Left, Right,