Skip to content

Commit

Permalink
Simplify search logic and move it into SearchIndex struct
Browse files Browse the repository at this point in the history
  • Loading branch information
hpeebles committed Sep 26, 2024
1 parent dc82207 commit 95489df
Show file tree
Hide file tree
Showing 16 changed files with 141 additions and 77 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use candid::CandidType;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use ts_export::ts_export;
use types::{ChannelId, MessageMatch, UserId};

Expand All @@ -9,7 +10,7 @@ pub struct Args {
pub channel_id: ChannelId,
pub search_term: String,
pub max_results: u8,
pub users: Option<Vec<UserId>>,
pub users: Option<HashSet<UserId>>,
}

#[ts_export(community, search_channel)]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::jobs::import_groups::finalize_group_import;
use crate::lifecycle::{init_env, init_state};
use crate::memory::get_upgrades_memory;
use crate::{read_state, Data};
use crate::{mutate_state, read_state, Data};
use canister_logger::LogEntry;
use canister_tracing_macros::trace;
use community_canister::post_upgrade::Args;
Expand Down Expand Up @@ -37,4 +37,10 @@ fn post_upgrade(args: Args) {
.data
.record_instructions_count(InstructionCountFunctionId::PostUpgrade, now)
});

mutate_state(|state| {
for channel in state.data.channels.iter_mut() {
channel.chat.events.populate_search_index();
}
});
}
11 changes: 4 additions & 7 deletions backend/canisters/community/impl/src/queries/search_channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,10 @@ fn search_channel_impl(args: Args, state: &RuntimeState) -> Response {

if let Some(member) = state.data.members.get(caller) {
if let Some(channel) = state.data.channels.get(&args.channel_id) {
match channel.chat.search(
member.user_id,
args.search_term,
args.users,
args.max_results,
state.env.now(),
) {
match channel
.chat
.search(member.user_id, args.search_term, args.users, args.max_results)
{
SearchResults::Success(matches) => Success(SuccessResult { matches }),
SearchResults::InvalidTerm => InvalidTerm,
SearchResults::TermTooLong(v) => TermTooLong(v),
Expand Down
3 changes: 2 additions & 1 deletion backend/canisters/group/api/src/queries/search_messages.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use candid::CandidType;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use ts_export::ts_export;
use types::{MessageMatch, UserId};

Expand All @@ -8,7 +9,7 @@ use types::{MessageMatch, UserId};
pub struct Args {
pub search_term: String,
pub max_results: u8,
pub users: Option<Vec<UserId>>,
pub users: Option<HashSet<UserId>>,
}

#[ts_export(group, search_messages)]
Expand Down
6 changes: 5 additions & 1 deletion backend/canisters/group/impl/src/lifecycle/post_upgrade.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::lifecycle::{init_env, init_state};
use crate::memory::get_upgrades_memory;
use crate::{read_state, Data};
use crate::{mutate_state, read_state, Data};
use canister_logger::LogEntry;
use canister_tracing_macros::trace;
use group_canister::post_upgrade::Args;
Expand Down Expand Up @@ -30,4 +30,8 @@ fn post_upgrade(args: Args) {
.data
.record_instructions_count(InstructionCountFunctionId::PostUpgrade, now)
});

mutate_state(|state| {
state.data.chat.events.populate_search_index();
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fn search_messages_impl(args: Args, state: &RuntimeState) -> Response {
match state
.data
.chat
.search(user_id, args.search_term, args.users, args.max_results, state.env.now())
.search(user_id, args.search_term, args.users, args.max_results)
{
SearchResults::Success(matches) => Success(SuccessResult { matches }),
SearchResults::InvalidTerm => InvalidTerm,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl PublicGroups {
.map(|c| {
let score = if let Some(query) = &query {
let document: Document = c.into();
document.calculate_score(query)
document.calculate_score(&query)
} else if c.hotness_score > 0 {
c.hotness_score
} else {
Expand Down
8 changes: 7 additions & 1 deletion backend/canisters/user/impl/src/lifecycle/post_upgrade.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::lifecycle::{init_env, init_state};
use crate::memory::get_upgrades_memory;
use crate::Data;
use crate::{mutate_state, Data};
use canister_logger::LogEntry;
use canister_tracing_macros::trace;
use ic_cdk::post_upgrade;
Expand All @@ -22,4 +22,10 @@ fn post_upgrade(args: Args) {
init_state(env, data, args.wasm_version);

info!(version = %args.wasm_version, "Post-upgrade complete");

mutate_state(|state| {
for chat in state.data.direct_chats.iter_mut() {
chat.events.populate_search_index();
}
});
}
5 changes: 3 additions & 2 deletions backend/canisters/user/impl/src/queries/search_messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use crate::guards::caller_is_owner;
use crate::{read_state, RuntimeState};
use canister_api_macros::query;
use search::Query;
use types::EventIndex;
use std::collections::HashSet;
use types::MessageIndex;
use user_canister::search_messages::{Response::*, *};

const MIN_TERM_LENGTH: u8 = 3;
Expand Down Expand Up @@ -35,7 +36,7 @@ fn search_messages_impl(args: Args, state: &RuntimeState) -> Response {
let matches =
direct_chat
.events
.search_messages(state.env.now(), EventIndex::default(), &query, args.max_results, my_user_id);
.search_messages(MessageIndex::default(), query, HashSet::new(), args.max_results, my_user_id);

Success(SuccessResult { matches })
}
72 changes: 45 additions & 27 deletions backend/libraries/chat_events/src/chat_events.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::expiring_events::ExpiringEvents;
use crate::last_updated_timestamps::LastUpdatedTimestamps;
use crate::search_index::SearchIndex;
use crate::*;
use candid::Principal;
use event_store_producer::{EventBuilder, EventStoreClient, Runtime};
Expand All @@ -11,7 +12,7 @@ use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::cmp::{max, Reverse};
use std::collections::hash_map::Entry::{Occupied, Vacant};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use types::{
AcceptP2PSwapResult, CallParticipant, CancelP2PSwapResult, CanisterId, Chat, CompleteP2PSwapResult,
CompletedCryptoTransaction, Cryptocurrency, DirectChatCreated, EventIndex, EventWrapper, EventsTimeToLiveUpdated,
Expand Down Expand Up @@ -40,9 +41,22 @@ pub struct ChatEvents {
last_updated_timestamps: LastUpdatedTimestamps,
video_call_in_progress: Timestamped<Option<VideoCall>>,
anonymized_id: String,
#[serde(default)]
search_index: SearchIndex,
}

impl ChatEvents {
// TODO remove this
/// Walks the main events list and pushes every message found there into the search index,
/// keyed by its message index together with its sender and a `Document` built from its content.
///
/// Entries already stored under the same message index are overwritten by `SearchIndex::push`.
pub fn populate_search_index(&mut self) {
    for entry in self.main.iter(None, true, EventIndex::default()) {
        // Anything that is not an actual event is skipped.
        let EventOrExpiredRangeInternal::Event(wrapper) = entry else {
            continue;
        };
        // Only message events are searchable.
        let ChatEventInternal::Message(message) = &wrapper.event else {
            continue;
        };
        self.search_index
            .push(message.message_index, message.sender, Document::from(&message.content));
    }
}

pub fn new_direct_chat(
them: UserId,
events_ttl: Option<Milliseconds>,
Expand All @@ -61,6 +75,7 @@ impl ChatEvents {
last_updated_timestamps: LastUpdatedTimestamps::default(),
video_call_in_progress: Timestamped::default(),
anonymized_id: hex::encode(anonymized_id.to_be_bytes()),
search_index: SearchIndex::default(),
};

events.push_event(None, ChatEventInternal::DirectChatCreated(DirectChatCreated {}), 0, now);
Expand Down Expand Up @@ -89,6 +104,7 @@ impl ChatEvents {
last_updated_timestamps: LastUpdatedTimestamps::default(),
video_call_in_progress: Timestamped::default(),
anonymized_id: hex::encode(anonymized_id.to_be_bytes()),
search_index: SearchIndex::default(),
};

events.push_event(
Expand Down Expand Up @@ -249,6 +265,10 @@ impl ChatEvents {
let already_edited = message.last_edited.is_some();
message.last_edited = Some(args.now);

let message_index = message.message_index;
let sender = message.sender;
let document = Document::from(&message.content);

if let Some(client) = event_store_client {
let new_length = message.content.text_length();
let payload = MessageEditedEventPayload {
Expand All @@ -270,6 +290,10 @@ impl ChatEvents {
)
}

if args.thread_root_message_index.is_none() {
self.search_index.push(message_index, sender, document);
}

add_to_metrics(
&mut self.metrics,
&mut self.per_user_metrics,
Expand Down Expand Up @@ -416,10 +440,15 @@ impl ChatEvents {
let (message, _) = self.message_internal_mut(EventIndex::default(), thread_root_message_index, message_id.into())?;

let deleted_by = message.deleted_by.clone()?;

let content = std::mem::replace(&mut message.content, MessageContentInternal::Deleted(deleted_by));
let sender = message.sender;

Some((content, message.sender))
if thread_root_message_index.is_none() {
let message_index = message.message_index;
self.search_index.remove(message_index);
}

Some((content, sender))
}

pub fn register_poll_vote(&mut self, args: RegisterPollVoteArgs) -> RegisterPollVoteResult {
Expand Down Expand Up @@ -1272,6 +1301,9 @@ impl ChatEvents {
let events_list = if let Some(root_message_index) = thread_root_message_index {
self.threads.get_mut(&root_message_index).unwrap()
} else {
if let ChatEventInternal::Message(m) = &event {
self.search_index.push(m.message_index, m.sender, Document::from(&m.content));
}
&mut self.main
};

Expand Down Expand Up @@ -1308,37 +1340,23 @@ impl ChatEvents {

pub fn search_messages(
&self,
now: TimestampMillis,
min_visible_event_index: EventIndex,
query: &Query,
min_visible_message_index: MessageIndex,
query: Query,
users: HashSet<UserId>,
max_results: u8,
my_user_id: UserId,
) -> Vec<MessageMatch> {
self.visible_main_events_reader(min_visible_event_index)
.iter(None, true)
.filter_map(|e| e.as_event())
.filter_map(|e| e.event.as_message().filter(|m| m.deleted_by.is_none()).map(|m| (e, m)))
.filter(|(_, m)| if query.users.is_empty() { true } else { query.users.contains(&m.sender) })
.filter_map(|(e, m)| {
if query.tokens.is_empty() {
Some((1, m))
} else {
let mut document: Document = (&m.content).into();
document.set_age(now - e.timestamp);
match document.calculate_score(query) {
0 => None,
n => Some((n, m)),
}
}
})
.sorted_unstable_by_key(|(score, _)| *score)
.rev()
let reader = self.main_events_reader();
self.search_index
.search_messages(min_visible_message_index, query, users)
.filter_map(|m| reader.message_internal(m.into()))
.filter(|m| m.deleted_by.is_none())
.take(max_results as usize)
.map(|(score, message)| MessageMatch {
.map(|message| MessageMatch {
message_index: message.message_index,
sender: message.sender,
content: message.content.hydrate(Some(my_user_id)),
score,
score: 1,
})
.collect()
}
Expand Down
1 change: 1 addition & 0 deletions backend/libraries/chat_events/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod events_map;
mod expiring_events;
mod last_updated_timestamps;
mod message_content_internal;
mod search_index;

pub use crate::chat_event_internal::*;
pub use crate::chat_events::*;
Expand Down
32 changes: 32 additions & 0 deletions backend/libraries/chat_events/src/search_index.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use search::{Document, Query};
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashSet};
use types::{MessageIndex, UserId};

/// Index of searchable messages, keyed by message index.
///
/// Each entry stores the sender of the message alongside the `Document` that search queries are
/// matched against. Because the entries live in a `BTreeMap` they are ordered by `MessageIndex`,
/// which allows a search to range over all indexes at or above a given minimum.
#[derive(Serialize, Deserialize, Default)]
pub struct SearchIndex {
// message index -> (sender, searchable document)
map: BTreeMap<MessageIndex, (UserId, Document)>,
}

impl SearchIndex {
    /// Stores `document` (and its `sender`) under `message_index`.
    ///
    /// Any entry already present for that index is replaced, so pushing again for the same
    /// message refreshes what is indexed for it.
    pub fn push(&mut self, message_index: MessageIndex, sender: UserId, document: Document) {
        self.map.insert(message_index, (sender, document));
    }

    /// Drops the entry for `message_index`; a no-op if that index is not present.
    pub fn remove(&mut self, message_index: MessageIndex) {
        self.map.remove(&message_index);
    }

    /// Lazily yields the indexes of the messages that satisfy the search, in descending
    /// message index order.
    ///
    /// * `min_visible_message_index` - entries below this index are never considered.
    /// * `query` - the parsed search term each candidate document is matched against.
    /// * `users` - when non-empty, restricts the results to messages sent by these users;
    ///   when empty, no sender restriction applies.
    ///
    /// The sender filter and the term match are combined conjunctively: with a non-empty
    /// `users` set a message must be sent by one of those users AND match the query.
    /// Previously the two conditions were OR-ed, which returned every message from the listed
    /// users regardless of the term, as well as term matches from any other user.
    pub fn search_messages(
        &self,
        min_visible_message_index: MessageIndex,
        query: Query,
        users: HashSet<UserId>,
    ) -> impl Iterator<Item = MessageIndex> + '_ {
        self.map
            .range(min_visible_message_index..)
            .rev()
            .filter(move |(_, (sender, doc))| {
                if users.is_empty() {
                    // No sender restriction: the outcome is decided by the query alone.
                    doc.is_match(&query)
                } else {
                    // A query without tokens is treated as "no search term", so a users-only
                    // search still returns those users' messages.
                    // NOTE(review): assumes `Query::tokens` is still a public field - confirm.
                    users.contains(sender) && (query.tokens.is_empty() || doc.is_match(&query))
                }
            })
            .map(|(id, _)| *id)
    }
}
8 changes: 3 additions & 5 deletions backend/libraries/group_chat_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,9 +504,8 @@ impl GroupChatCore {
&self,
user_id: UserId,
search_term: String,
users: Option<Vec<UserId>>,
users: Option<HashSet<UserId>>,
max_results: u8,
now: TimestampMillis,
) -> SearchResults {
use SearchResults::*;

Expand Down Expand Up @@ -534,12 +533,11 @@ impl GroupChatCore {
Some(p) => p,
};

let mut query = Query::parse(search_term);
query.users = HashSet::from_iter(users);
let query = Query::parse(search_term);

let matches = self
.events
.search_messages(now, member.min_visible_event_index(), &query, max_results, user_id);
.search_messages(member.min_visible_message_index(), query, users, max_results, user_id);

Success(matches)
}
Expand Down
1 change: 1 addition & 0 deletions backend/libraries/search/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ edition = "2021"

[dependencies]
types = { path = "../types" }
serde = { workspace = true, features = ["derive"] }
Loading

0 comments on commit 95489df

Please sign in to comment.