Skip to content

Commit

Permalink
feat: use Deref instead of reference (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChieloNewctle authored Nov 27, 2023
1 parent de68a31 commit eb07841
Show file tree
Hide file tree
Showing 9 changed files with 245 additions and 78 deletions.
2 changes: 1 addition & 1 deletion benches/tokenize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ fn tokenize_with_hf(tokenizer: &HFTokenizer, seq: &str) -> Vec<u32> {
}

fn tokenize_with_sam<T: TransitionTable<KeyType = char>>(
tokenizer: &GreedyTokenizer<T, u32>,
tokenizer: &GreedyTokenizer<T, u32, &GeneralSAM<T>>,
seq: &str,
) -> Vec<u32> {
tokenizer
Expand Down
7 changes: 5 additions & 2 deletions src/sam/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,14 @@ impl<TransTable: TransitionTable> GeneralSAM<TransTable> {
self.node_pool.get(node_id)
}

pub fn get_root_state(&self) -> GeneralSAMState<TransTable> {
pub fn get_root_state(&self) -> GeneralSAMState<TransTable, &GeneralSAM<TransTable>> {
self.get_state(SAM_ROOT_NODE_ID)
}

pub fn get_state(&self, node_id: GeneralSAMNodeID) -> GeneralSAMState<TransTable> {
pub fn get_state(
&self,
node_id: GeneralSAMNodeID,
) -> GeneralSAMState<TransTable, &GeneralSAM<TransTable>> {
if node_id < self.node_pool.len() {
GeneralSAMState { sam: self, node_id }
} else {
Expand Down
87 changes: 61 additions & 26 deletions src/sam/state.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,59 @@
//! States of a general suffix automaton.

use std::ops::Deref;

use crate::{TravelEvent, TrieNodeAlike};

use super::{GeneralSAM, GeneralSAMNode, TransitionTable, SAM_NIL_NODE_ID, SAM_ROOT_NODE_ID};

#[derive(Debug)]
pub struct GeneralSAMState<'s, TransTable: TransitionTable> {
pub sam: &'s GeneralSAM<TransTable>,
pub struct GeneralSAMState<
TransTable: TransitionTable,
SAMRef: Deref<Target = GeneralSAM<TransTable>>,
> {
pub sam: SAMRef,
pub node_id: usize,
}

impl<'s, TransTable: TransitionTable> Clone for GeneralSAMState<'s, TransTable> {
impl<TransTable: TransitionTable, SAMRef: Deref<Target = GeneralSAM<TransTable>> + Clone> Clone
for GeneralSAMState<TransTable, SAMRef>
{
fn clone(&self) -> Self {
Self {
sam: self.sam,
sam: self.sam.clone(),
node_id: self.node_id,
}
}
}

impl<'s, TransTable: TransitionTable<KeyType = u8>> GeneralSAMState<'s, TransTable> {
pub fn feed_bytes(self, seq: &'s str) -> Self {
impl<TransTable: TransitionTable<KeyType = u8>, SAMRef: Deref<Target = GeneralSAM<TransTable>>>
GeneralSAMState<TransTable, SAMRef>
{
pub fn feed_bytes(self, seq: &str) -> Self {
self.feed_ref(seq.as_bytes())
}
}

impl<'s, TransTable: TransitionTable<KeyType = char>> GeneralSAMState<'s, TransTable> {
impl<
TransTable: TransitionTable<KeyType = char>,
SAMRef: Deref<Target = GeneralSAM<TransTable>>,
> GeneralSAMState<TransTable, SAMRef>
{
pub fn feed_chars(self, seq: &str) -> Self {
self.feed(seq.chars())
}
}

impl<'s, TransTable: TransitionTable> GeneralSAMState<'s, TransTable> {
impl<TransTable: TransitionTable, SAMRef: Deref<Target = GeneralSAM<TransTable>>>
GeneralSAMState<TransTable, SAMRef>
{
pub fn inner_as_ref(&self) -> GeneralSAMState<TransTable, &GeneralSAM<TransTable>> {
GeneralSAMState {
sam: &self.sam,
node_id: self.node_id,
}
}

pub fn is_nil(&self) -> bool {
self.node_id == SAM_NIL_NODE_ID
}
Expand All @@ -46,10 +68,8 @@ impl<'s, TransTable: TransitionTable> GeneralSAMState<'s, TransTable> {
.unwrap_or(false)
}

pub fn get_non_nil_trans(&self, key: &TransTable::KeyType) -> Option<Self> {
self.get_node()
.and_then(|node| node.trans.get(key))
.map(|x| self.sam.get_state(*x))
pub fn get_sam_ref(&self) -> &GeneralSAM<TransTable> {
&self.sam
}

pub fn get_node(&self) -> Option<&GeneralSAMNode<TransTable>> {
Expand Down Expand Up @@ -86,17 +106,21 @@ impl<'s, TransTable: TransitionTable> GeneralSAMState<'s, TransTable> {
}
self
}
}

impl<'s, TransTable: TransitionTable> GeneralSAMState<'s, TransTable> {
pub fn feed_ref<Seq: IntoIterator<Item = &'s TransTable::KeyType>>(self, seq: Seq) -> Self {
pub fn feed_ref<'s, Seq: IntoIterator<Item = &'s TransTable::KeyType>>(self, seq: Seq) -> Self
where
<TransTable as TransitionTable>::KeyType: 's,
{
self.feed_ref_iter(seq.into_iter())
}

pub fn feed_ref_iter<Iter: Iterator<Item = &'s TransTable::KeyType>>(
pub fn feed_ref_iter<'s, Iter: Iterator<Item = &'s TransTable::KeyType>>(
mut self,
iter: Iter,
) -> Self {
) -> Self
where
<TransTable as TransitionTable>::KeyType: 's,
{
for t in iter {
if self.is_nil() {
break;
Expand All @@ -105,21 +129,36 @@ impl<'s, TransTable: TransitionTable> GeneralSAMState<'s, TransTable> {
}
self
}
}

impl<TransTable: TransitionTable, SAMRef: Deref<Target = GeneralSAM<TransTable>> + Clone>
GeneralSAMState<TransTable, SAMRef>
{
pub fn get_non_nil_trans(&self, key: &TransTable::KeyType) -> Option<Self> {
self.get_node()
.and_then(|node| node.trans.get(key))
.map(|x| Self {
sam: self.sam.clone(),
node_id: *x,
})
}

fn wrap_travel_along_callback<
's,
TN: TrieNodeAlike<InnerType = TransTable::KeyType>,
ExtraType,
ErrorType,
F: 's
+ FnMut(
TravelEvent<(&GeneralSAMState<TransTable>, &TN), ExtraType, TN::InnerType>,
TravelEvent<(&Self, &TN), ExtraType, TN::InnerType>,
) -> Result<ExtraType, ErrorType>,
>(
&'s self,
mut callback: F,
) -> impl FnMut(
TravelEvent<&TN, (GeneralSAMState<'s, TransTable>, ExtraType), TN::InnerType>,
) -> Result<(GeneralSAMState<'s, TransTable>, ExtraType), ErrorType> {
TravelEvent<&TN, (Self, ExtraType), TN::InnerType>,
) -> Result<(Self, ExtraType), ErrorType>
+ 's {
move |event| match event {
TravelEvent::PushRoot(trie_root) => {
let res = callback(TravelEvent::PushRoot((self, trie_root)))?;
Expand All @@ -143,9 +182,7 @@ impl<'s, TransTable: TransitionTable> GeneralSAMState<'s, TransTable> {
TN: TrieNodeAlike<InnerType = TransTable::KeyType> + Clone,
ExtraType,
ErrorType,
F: FnMut(
TravelEvent<(&GeneralSAMState<TransTable>, &TN), ErrorType, TN::InnerType>,
) -> Result<ErrorType, ExtraType>,
F: FnMut(TravelEvent<(&Self, &TN), ErrorType, TN::InnerType>) -> Result<ErrorType, ExtraType>,
>(
&self,
trie_node: TN,
Expand All @@ -158,9 +195,7 @@ impl<'s, TransTable: TransitionTable> GeneralSAMState<'s, TransTable> {
TN: TrieNodeAlike<InnerType = TransTable::KeyType>,
ExtraType,
ErrorType,
F: FnMut(
TravelEvent<(&GeneralSAMState<TransTable>, &TN), ErrorType, TN::InnerType>,
) -> Result<ErrorType, ExtraType>,
F: FnMut(TravelEvent<(&Self, &TN), ErrorType, TN::InnerType>) -> Result<ErrorType, ExtraType>,
>(
&self,
trie_node: TN,
Expand Down
31 changes: 29 additions & 2 deletions src/tests/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ fn test_rope() {

#[cfg(feature = "trie")]
mod trie {
use std::collections::BTreeMap;
use std::{collections::BTreeMap, ops::Deref};

use rand::{
distributions::{Alphanumeric, DistString},
Expand Down Expand Up @@ -129,8 +129,9 @@ mod trie {
T: Clone,
TransTable: TransitionTable<KeyType = T>,
Iter: Iterator<Item = T>,
SAMRef: Deref<Target = GeneralSAM<TransTable>>,
>(
tokenizer: &GreedyTokenizer<TransTable, usize>,
tokenizer: &GreedyTokenizer<TransTable, usize, SAMRef>,
trie: &Trie<TransTable>,
seq: Iter,
) {
Expand Down Expand Up @@ -188,6 +189,32 @@ mod trie {
case_tokenizer(&tokenizer, &trie, "abc".bytes());
}

#[test]
fn test_tokenizer_owning_sam() {
let vocab = [
"a", "ab", "b", "bc", "c", "d", "e", "f", "cd", "abcde", "你好", "🧡",
];
let mut trie = Trie::<BTreeTransTable<u8>>::default();
let mut id_to_word = BTreeMap::new();
for word in vocab {
id_to_word.insert(trie.insert_iter(word.bytes()), word);
}

let sam = GeneralSAM::<BTreeTransTable<u8>>::from_trie(trie.get_root_state());

let tokenizer = GreedyTokenizer::<BTreeTransTable<_>, _, _>::build_from_sam_and_trie(
sam,
trie.get_root_state(),
);

case_tokenizer(&tokenizer, &trie, "abcde".bytes());
case_tokenizer(&tokenizer, &trie, "abcdf".bytes());
case_tokenizer(&tokenizer, &trie, "abca".bytes());
case_tokenizer(&tokenizer, &trie, "Hi,你好吗?".bytes());
case_tokenizer(&tokenizer, &trie, "🧡🧡🧡🧡🧡!".bytes());
case_tokenizer(&tokenizer, &trie, "abc".bytes());
}

fn case_tokenizer_vocab<
T: Clone + Ord + Eq + std::hash::Hash,
TransTable: TransitionTable<KeyType = T>,
Expand Down
73 changes: 54 additions & 19 deletions src/trie.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Trie, supporting `TrieNodeAlike`.

use std::ops::Deref;

use crate::{ConstructiveTransitionTable, GeneralSAMNodeID, TransitionTable, TrieNodeAlike};

pub type TrieNodeID = GeneralSAMNodeID;
Expand All @@ -18,12 +20,23 @@ pub struct Trie<TransTable: TransitionTable> {
node_pool: Vec<TrieNode<TransTable>>,
}

#[derive(Clone, Debug)]
pub struct TrieState<'s, TransTable: TransitionTable> {
pub trie: &'s Trie<TransTable>,
#[derive(Debug)]
pub struct TrieState<TransTable: TransitionTable, TrieRef: Deref<Target = Trie<TransTable>>> {
pub trie: TrieRef,
pub node_id: TrieNodeID,
}

impl<TransTable: TransitionTable, TrieRef: Deref<Target = Trie<TransTable>> + Clone> Clone
for TrieState<TransTable, TrieRef>
{
fn clone(&self) -> Self {
Self {
trie: self.trie.clone(),
node_id: self.node_id,
}
}
}

impl<TransTable: ConstructiveTransitionTable> TrieNode<TransTable> {
fn new(parent: TrieNodeID) -> Self {
Self {
Expand Down Expand Up @@ -70,7 +83,7 @@ impl<TransTable: TransitionTable> Trie<TransTable> {
self.node_pool.len()
}

pub fn get_state(&self, node_id: TrieNodeID) -> TrieState<TransTable> {
pub fn get_state(&self, node_id: TrieNodeID) -> TrieState<TransTable, &Trie<TransTable>> {
if node_id >= self.node_pool.len() {
return TrieState {
trie: self,
Expand All @@ -91,7 +104,7 @@ impl<TransTable: TransitionTable> Trie<TransTable> {
self.get_node(TRIE_ROOT_NODE_ID).unwrap()
}

pub fn get_root_state(&self) -> TrieState<TransTable> {
pub fn get_root_state(&self) -> TrieState<TransTable, &Trie<TransTable>> {
self.get_state(TRIE_ROOT_NODE_ID)
}

Expand Down Expand Up @@ -142,7 +155,16 @@ impl<TransTable: ConstructiveTransitionTable> Trie<TransTable> {
}
}

impl<'s, TransTable: TransitionTable> TrieState<'s, TransTable> {
impl<TransTable: TransitionTable, TrieRef: Deref<Target = Trie<TransTable>>>
TrieState<TransTable, TrieRef>
{
pub fn inner_as_ref(&self) -> TrieState<TransTable, &Trie<TransTable>> {
TrieState {
trie: &self.trie,
node_id: self.node_id,
}
}

pub fn is_nil(&self) -> bool {
self.node_id == TRIE_NIL_NODE_ID
}
Expand All @@ -151,7 +173,7 @@ impl<'s, TransTable: TransitionTable> TrieState<'s, TransTable> {
self.node_id == TRIE_ROOT_NODE_ID
}

pub fn get_node(&self) -> Option<&'s TrieNode<TransTable>> {
pub fn get_node(&self) -> Option<&TrieNode<TransTable>> {
self.trie.get_node(self.node_id)
}

Expand All @@ -170,15 +192,28 @@ impl<'s, TransTable: TransitionTable> TrieState<'s, TransTable> {
self.node_id = TRIE_NIL_NODE_ID;
}
}

pub fn feed_iter<Iter: Iterator<Item = TransTable::KeyType>>(&mut self, iter: Iter) {
iter.for_each(|x| self.goto(&x));
}

pub fn feed_ref_iter<'s, Iter: Iterator<Item = &'s TransTable::KeyType>>(
&'s mut self,
iter: Iter,
) {
iter.for_each(|x| self.goto(x));
}
}

#[derive(Clone, Debug)]
pub struct NextTrieStateIter<'s, TransTable: TransitionTable> {
state: TrieState<'s, TransTable>,
trie: &'s Trie<TransTable>,
iter: TransTable::IterType<'s>,
}

impl<'s, TransTable: TransitionTable> TrieNodeAlike for TrieState<'s, TransTable> {
impl<'s, TransTable: TransitionTable> TrieNodeAlike
for TrieState<TransTable, &'s Trie<TransTable>>
{
type InnerType = TransTable::KeyType;
type NextStateIter = NextTrieStateIter<'s, TransTable>;

Expand All @@ -187,23 +222,23 @@ impl<'s, TransTable: TransitionTable> TrieNodeAlike for TrieState<'s, TransTable
}

fn next_states(self) -> Self::NextStateIter {
let iter = self.get_node().unwrap().trans.iter();
NextTrieStateIter { state: self, iter }
let iter = self.trie.get_node(self.node_id).unwrap().trans.iter();
NextTrieStateIter {
trie: self.trie,
iter,
}
}
}

impl<'s, TransTable: TransitionTable> Iterator for NextTrieStateIter<'s, TransTable> {
type Item = (TransTable::KeyType, TrieState<'s, TransTable>);
type Item = (
TransTable::KeyType,
TrieState<TransTable, &'s Trie<TransTable>>,
);

fn next(&mut self) -> Option<Self::Item> {
self.iter
.next()
.map(|(t, next_node_id)| (t.clone(), self.state.trie.get_state(*next_node_id)))
}
}

impl<'s, TransTable: TransitionTable> NextTrieStateIter<'s, TransTable> {
pub fn get_state(&self) -> &TrieState<TransTable> {
&self.state
.map(|(t, next_node_id)| (t.clone(), self.trie.get_state(*next_node_id)))
}
}
Loading

0 comments on commit eb07841

Please sign in to comment.