Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow passing an iterator into prefix searches, insted of a slice. #36

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ assert_eq!(
);

// common_prefix_search(): Find words which is included in `query`'s prefix.
let results_in_u8s: Vec<Vec<u8>> = trie.common_prefix_search("すしや").collect();
let results_in_str: Vec<String> = trie.common_prefix_search("すしや").collect();
let results_in_u8s: Vec<Vec<u8>> = trie.common_prefix_search("すしや".bytes()).collect();
let results_in_str: Vec<String> = trie.common_prefix_search("すしや".bytes()).collect();
assert_eq!(
results_in_str,
vec![
Expand Down
8 changes: 4 additions & 4 deletions benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,9 @@ mod trie {
// Tested function takes too short compared to build().
// So loop many times.
let results_in_str: Vec<String> =
trie.common_prefix_search("すしをにぎる").collect();
trie.common_prefix_search("すしをにぎる".bytes()).collect();
for _ in 0..(times - 1) {
for entry in trie.common_prefix_search("すしをにぎる") {
for entry in trie.common_prefix_search("すしをにぎる".bytes()) {
black_box::<Vec<u8>>(entry);
}
}
Expand Down Expand Up @@ -249,12 +249,12 @@ mod trie {
// iter_batched() does not properly time `routine` time when `setup` time is far longer than `routine` time.
// Tested function takes too short compared to build(). So loop many times.
let result = trie
.common_prefix_search::<Vec<u8>, _>("すしをにぎる")
.common_prefix_search::<Vec<u8>, _, _>("すしをにぎる".bytes())
.next()
.is_some();
for _ in 0..(times - 1) {
let _ = trie
.common_prefix_search::<Vec<u8>, _>("すしをにぎる")
.common_prefix_search::<Vec<u8>, _, _>("すしをにぎる".bytes())
.next()
.is_some();
}
Expand Down
27 changes: 15 additions & 12 deletions src/iter/prefix_iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,28 @@ use crate::try_collect::{TryCollect, TryFromIterator};
use louds_rs::LoudsNodeNum;
use std::marker::PhantomData;

#[derive(Debug, Clone)]
#[derive(Debug)]
/// Iterates through all the common prefixes of a given query.
pub struct PrefixIter<'a, Label, Value, C, M> {
pub struct PrefixIter<'a, Label, I, Value, C, M> {
trie: &'a Trie<Label, Value>,
query: Vec<Label>,
index: usize,
query: I,
node: LoudsNodeNum,
buffer: Vec<&'a Label>,
consume: Option<&'a Value>,
col: PhantomData<(C, M)>,
}

impl<'a, Label: Ord + Clone, Value, C, M> PrefixIter<'a, Label, Value, C, M> {
impl<'a, 'b, Label: Ord + Clone, Value, I: Iterator<Item = Label>, C, M>
PrefixIter<'a, Label, I, Value, C, M>
{
#[inline]
pub(crate) fn new(trie: &'a Trie<Label, Value>, query: impl AsRef<[Label]>) -> Self {
pub(crate) fn new(
trie: &'a Trie<Label, Value>,
query: impl IntoIterator<IntoIter = I>,
) -> Self {
Self {
trie,
query: query.as_ref().to_vec(),
index: 0,
query: query.into_iter(),
node: LoudsNodeNum(1),
buffer: Vec::new(),
consume: None,
Expand All @@ -30,18 +33,19 @@ impl<'a, Label: Ord + Clone, Value, C, M> PrefixIter<'a, Label, Value, C, M> {
}
}

impl<'a, Label: Ord + Clone, Value, C, M> Iterator for PrefixIter<'a, Label, Value, C, M>
impl<'a, 'b, Label: Ord + Clone, Value, I: Iterator<Item = Label>, C, M> Iterator
for PrefixIter<'a, Label, I, Value, C, M>
where
C: TryFromIterator<Label, M>,
{
type Item = (C, &'a Value);
fn next(&mut self) -> Option<Self::Item> {
while self.consume.is_none() {
if let Some(chr) = self.query.get(self.index) {
if let Some(chr) = self.query.next() {
let children_node_nums: Vec<_> = self.trie.children_node_nums(self.node).collect();
let res = self
.trie
.bin_search_by_children_labels(chr, &children_node_nums[..]);
.bin_search_by_children_labels(&chr, &children_node_nums[..]);
match res {
Ok(j) => {
let child_node_num = children_node_nums[j];
Expand All @@ -54,7 +58,6 @@ where
} else {
return None;
}
self.index += 1;
}
if let Some(v) = self.consume.take() {
let col = self.buffer.clone();
Expand Down
23 changes: 14 additions & 9 deletions src/map/trie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ impl<Label: Ord, Value> Trie<Label, Value> {
///
/// Note: A prefix may be an exact match or not, and an exact match may be a
/// prefix or not.
pub fn is_prefix(&self, query: impl AsRef<[Label]>) -> bool {
pub fn is_prefix(&self, query: impl IntoIterator<Item = Label>) -> bool {
let mut cur_node_num = LoudsNodeNum(1);

for chr in query.as_ref().iter() {
for chr in query {
let children_node_nums: Vec<_> = self.children_node_nums(cur_node_num).collect();
let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
let res = self.bin_search_by_children_labels(&chr, &children_node_nums[..]);
match res {
Ok(j) => cur_node_num = children_node_nums[j],
Err(_) => return false,
Expand Down Expand Up @@ -128,13 +128,14 @@ impl<Label: Ord, Value> Trie<Label, Value> {
}

/// Return the common prefixes of `query`.
pub fn common_prefix_search<C, M>(
pub fn common_prefix_search<'a, C, M, I>(
&self,
query: impl AsRef<[Label]>,
) -> PrefixIter<'_, Label, Value, C, M>
query: I,
) -> PrefixIter<'_, Label, I::IntoIter, Value, C, M>
where
C: TryFromIterator<Label, M>,
Label: Clone,
I: IntoIterator<Item = Label>,
{
PrefixIter::new(self, query)
}
Expand Down Expand Up @@ -338,7 +339,11 @@ mod search_tests {
assert!(trie.exact_match("").is_none());
let _ = trie.predictive_search::<String, _>("").next();
let _ = trie.postfix_search::<String, _>("").next();
let _ = trie.common_prefix_search::<String, _>("").next();

let trie2 = build_trie2();
let _ = trie2
.common_prefix_search::<String, _, _>(&mut "".chars())
.next();
}

#[test]
Expand Down Expand Up @@ -401,7 +406,7 @@ mod search_tests {
fn $name() {
let (query, expected_match) = $value;
let trie = super::build_trie();
let result = trie.is_prefix(query);
let result = trie.is_prefix(query.bytes());
assert_eq!(result, expected_match);
}
)*
Expand Down Expand Up @@ -488,7 +493,7 @@ mod search_tests {
fn $name() {
let (query, expected_results) = $value;
let trie = super::build_trie();
let results: Vec<(String, &u8)> = trie.common_prefix_search(query).collect();
let results: Vec<(String, &u8)> = trie.common_prefix_search(&mut query.bytes()).collect();
let expected_results: Vec<(String, &u8)> = expected_results.iter().map(|s| (s.0.to_string(), &s.1)).collect();
assert_eq!(results, expected_results);
}
Expand Down
32 changes: 23 additions & 9 deletions src/trie/trie_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,19 @@ impl<Label: Ord> Trie<Label> {
///
/// let trie = Trie::from_iter(["a", "app", "apple", "better", "application"]);
///
/// let results: Vec<String> = trie.common_prefix_search("application").collect();
/// let results: Vec<String> = trie.common_prefix_search("application".bytes()).collect();
///
/// assert_eq!(results, vec!["a", "app", "application"]);
///
/// ```
pub fn common_prefix_search<C, M>(
pub fn common_prefix_search<C, M, I>(
&self,
query: impl AsRef<[Label]>,
) -> Keys<PrefixIter<'_, Label, (), C, M>>
query: I,
) -> Keys<PrefixIter<'_, Label, I::IntoIter, (), C, M>>
where
C: TryFromIterator<Label, M>,
Label: Clone,
I: IntoIterator<Item = Label>,
{
// TODO: We could return Keys iterators instead of collecting.
self.0.common_prefix_search(query).keys()
Expand Down Expand Up @@ -146,7 +147,7 @@ impl<Label: Ord> Trie<Label> {
///
/// Note: A prefix may be an exact match or not, and an exact match may be a
/// prefix or not.
pub fn is_prefix(&self, query: impl AsRef<[Label]>) -> bool {
pub fn is_prefix(&self, query: impl IntoIterator<Item = Label>) -> bool {
self.0.is_prefix(query)
}

Expand Down Expand Up @@ -194,6 +195,17 @@ mod search_tests {
builder.build()
}

fn build_trie2() -> Trie<char> {
let mut builder = TrieBuilder::new();
builder.insert("a".chars());
builder.insert("app".chars());
builder.insert("apple".chars());
builder.insert("better".chars());
builder.insert("application".chars());
builder.insert("アップル🍎".chars());
builder.build()
}

#[test]
fn trie_from_iter() {
let trie = Trie::<u8>::from_iter(["a", "app", "apple", "better", "application"]);
Expand Down Expand Up @@ -240,7 +252,9 @@ mod search_tests {
assert!(!trie.exact_match(""));
let _ = trie.predictive_search::<String, _>("").next();
let _ = trie.postfix_search::<String, _>("").next();
let _ = trie.common_prefix_search::<String, _>("").next();
let _ = trie
.common_prefix_search::<String, _, _>(&mut "".bytes())
.next();
}

#[cfg(feature = "mem_dbg")]
Expand Down Expand Up @@ -322,7 +336,7 @@ mod search_tests {
fn $name() {
let (query, expected_match) = $value;
let trie = super::build_trie();
let result = trie.is_prefix(query);
let result = trie.is_prefix(query.bytes());
assert_eq!(result, expected_match);
}
)*
Expand Down Expand Up @@ -378,8 +392,8 @@ mod search_tests {
#[test]
fn $name() {
let (query, expected_results) = $value;
let trie = super::build_trie();
let results: Vec<String> = trie.common_prefix_search(query).collect();
let trie = super::build_trie2();
let results: Vec<String> = trie.common_prefix_search(&mut query.chars()).collect();
assert_eq!(results, expected_results);
}
)*
Expand Down