From 003c5d34e43d14bbc3cd0ee7d8d315998a78f8bb Mon Sep 17 00:00:00 2001 From: Hamish Peebles Date: Mon, 9 Dec 2024 11:54:58 +0000 Subject: [PATCH] Add `Key` and `KeyPrefix` traits to simplify adding additional key types (#7013) --- backend/canisters/community/impl/src/lib.rs | 13 +- .../impl/src/model/events/stable_memory.rs | 4 +- .../impl/src/model/members/stable_memory.rs | 13 +- .../impl/src/updates/delete_channel.rs | 8 +- backend/canisters/group/impl/src/lib.rs | 6 +- backend/canisters/user/impl/src/lib.rs | 4 +- .../impl/src/updates/delete_direct_chat.rs | 17 +- .../chat_events/src/stable_memory/mod.rs | 36 +- .../src/members/stable_memory.rs | 20 +- .../libraries/stable_memory_map/src/key.rs | 555 ------------------ .../libraries/stable_memory_map/src/keys.rs | 115 ++++ .../stable_memory_map/src/keys/chat_event.rs | 250 ++++++++ .../src/keys/community_event.rs | 66 +++ .../stable_memory_map/src/keys/macros.rs | 49 ++ .../stable_memory_map/src/keys/member.rs | 140 +++++ .../libraries/stable_memory_map/src/lib.rs | 82 ++- 16 files changed, 743 insertions(+), 635 deletions(-) delete mode 100644 backend/libraries/stable_memory_map/src/key.rs create mode 100644 backend/libraries/stable_memory_map/src/keys.rs create mode 100644 backend/libraries/stable_memory_map/src/keys/chat_event.rs create mode 100644 backend/libraries/stable_memory_map/src/keys/community_event.rs create mode 100644 backend/libraries/stable_memory_map/src/keys/macros.rs create mode 100644 backend/libraries/stable_memory_map/src/keys/member.rs diff --git a/backend/canisters/community/impl/src/lib.rs b/backend/canisters/community/impl/src/lib.rs index 2da2153541..a8f50a4026 100644 --- a/backend/canisters/community/impl/src/lib.rs +++ b/backend/canisters/community/impl/src/lib.rs @@ -30,7 +30,7 @@ use rand::rngs::StdRng; use rand::RngCore; use serde::{Deserialize, Deserializer, Serialize}; use serde_bytes::ByteBuf; -use stable_memory_map::{ChatEventKeyPrefix, KeyPrefix}; +use stable_memory_map::{BaseKeyPrefix, ChatEventKeyPrefix}; use std::cell::RefCell; use std::collections::BTreeMap; use std::ops::Deref; @@ -243,12 +243,9 @@ impl RuntimeState { } final_prize_payments.extend(result.final_prize_payments); for thread in result.threads { - self.data - .stable_memory_keys_to_garbage_collect - .push(KeyPrefix::from(ChatEventKeyPrefix::new_from_channel( - channel.id, - Some(thread.root_message_index), - ))); + self.data.stable_memory_keys_to_garbage_collect.push(BaseKeyPrefix::from( + ChatEventKeyPrefix::new_from_channel(channel.id, Some(thread.root_message_index)), + )); } } jobs::garbage_collect_stable_memory::start_job_if_required(self); @@ -379,7 +376,7 @@ struct Data { expiring_member_actions: ExpiringMemberActions, user_cache: UserCache, user_event_sync_queue: GroupedTimerJobQueue, - stable_memory_keys_to_garbage_collect: Vec, + stable_memory_keys_to_garbage_collect: Vec, members_migrated_to_stable_memory: bool, #[serde(default)] bots: GroupBots, diff --git a/backend/canisters/community/impl/src/model/events/stable_memory.rs b/backend/canisters/community/impl/src/model/events/stable_memory.rs index dffc8363d3..63f11b4b46 100644 --- a/backend/canisters/community/impl/src/model/events/stable_memory.rs +++ b/backend/canisters/community/impl/src/model/events/stable_memory.rs @@ -1,7 +1,7 @@ use crate::model::events::CommunityEventInternal; use candid::Deserialize; use serde::Serialize; -use stable_memory_map::{with_map_mut, CommunityEventKeyPrefix}; +use stable_memory_map::{with_map_mut, CommunityEventKeyPrefix, KeyPrefix}; use types::EventWrapperInternal; #[derive(Serialize, Deserialize)] @@ -19,7 +19,7 @@ impl Default for EventsStableStorage { impl EventsStableStorage { pub fn insert(&mut self, event: EventWrapperInternal) { - with_map_mut(|m| m.insert(self.prefix.create_key(event.index).into(), event_to_bytes(event))); + with_map_mut(|m| m.insert(self.prefix.create_key(&event.index), event_to_bytes(event))); } } diff --git a/backend/canisters/community/impl/src/model/members/stable_memory.rs b/backend/canisters/community/impl/src/model/members/stable_memory.rs index a5a1b81e46..b50a746ed9 100644 --- a/backend/canisters/community/impl/src/model/members/stable_memory.rs +++ b/backend/canisters/community/impl/src/model/members/stable_memory.rs @@ -1,7 +1,7 @@ use crate::CommunityMemberInternal; use candid::Deserialize; use serde::Serialize; -use stable_memory_map::{with_map, with_map_mut, MemberKeyPrefix}; +use stable_memory_map::{with_map, with_map_mut, KeyPrefix, MemberKeyPrefix}; use std::collections::BTreeSet; use types::{is_default, CommunityRole, TimestampMillis, Timestamped, UserId, UserType, Version}; @@ -19,18 +19,18 @@ impl MembersStableStorage { pub fn get(&self, user_id: &UserId) -> Option { with_map(|m| { - m.get(&self.prefix.create_key(*user_id).into()) + m.get(self.prefix.create_key(user_id)) .map(|v| bytes_to_member(&v).hydrate(*user_id)) }) } pub fn insert(&mut self, member: CommunityMemberInternal) { - with_map_mut(|m| m.insert(self.prefix.create_key(member.user_id).into(), member_to_bytes(member.into()))); + with_map_mut(|m| m.insert(self.prefix.create_key(&member.user_id), member_to_bytes(member.into()))); } pub fn remove(&mut self, user_id: &UserId) -> Option { with_map_mut(|m| { - m.remove(&self.prefix.create_key(*user_id).into()) + m.remove(self.prefix.create_key(user_id)) .map(|v| bytes_to_member(&v).hydrate(*user_id)) }) } @@ -38,11 +38,10 @@ impl MembersStableStorage { #[cfg(test)] pub fn all_members(&self) -> Vec { use candid::Principal; - use stable_memory_map::{Key, MemberKey}; + use stable_memory_map::Key; with_map(|m| { - m.range(Key::from(self.prefix.create_key(Principal::from_slice(&[]).into()))..) - .map_while(|(k, v)| MemberKey::try_from(k).ok().map(|k| (k, v))) + m.range(self.prefix.create_key(&Principal::from_slice(&[]).into())..) .take_while(|(k, _)| k.matches_prefix(&self.prefix)) .map(|(k, v)| bytes_to_member(&v).hydrate(k.user_id())) .collect() diff --git a/backend/canisters/community/impl/src/updates/delete_channel.rs b/backend/canisters/community/impl/src/updates/delete_channel.rs index b2fd650439..2eb69b34da 100644 --- a/backend/canisters/community/impl/src/updates/delete_channel.rs +++ b/backend/canisters/community/impl/src/updates/delete_channel.rs @@ -5,7 +5,7 @@ use crate::{ use canister_api_macros::update; use canister_tracing_macros::trace; use community_canister::delete_channel::{Response::*, *}; -use stable_memory_map::{ChatEventKeyPrefix, KeyPrefix, MemberKeyPrefix}; +use stable_memory_map::{BaseKeyPrefix, ChatEventKeyPrefix, MemberKeyPrefix}; use types::{ChannelDeleted, ChannelId}; #[update(msgpack = true)] @@ -54,13 +54,13 @@ fn delete_channel_impl(channel_id: ChannelId, state: &mut RuntimeState) -> Respo state .data .stable_memory_keys_to_garbage_collect - .push(KeyPrefix::from(ChatEventKeyPrefix::new_from_channel(channel_id, None))); + .push(BaseKeyPrefix::from(ChatEventKeyPrefix::new_from_channel(channel_id, None))); for message_index in channel.chat.events.thread_keys() { state .data .stable_memory_keys_to_garbage_collect - .push(KeyPrefix::from(ChatEventKeyPrefix::new_from_channel( + .push(BaseKeyPrefix::from(ChatEventKeyPrefix::new_from_channel( channel_id, Some(message_index), ))); @@ -69,7 +69,7 @@ fn delete_channel_impl(channel_id: ChannelId, state: &mut RuntimeState) -> Respo state .data .stable_memory_keys_to_garbage_collect - .push(KeyPrefix::from(MemberKeyPrefix::new_from_channel(channel_id))); + .push(BaseKeyPrefix::from(MemberKeyPrefix::new_from_channel(channel_id))); crate::jobs::garbage_collect_stable_memory::start_job_if_required(state); diff --git a/backend/canisters/group/impl/src/lib.rs b/backend/canisters/group/impl/src/lib.rs index b68807881b..3d17bfdd32 100644 --- a/backend/canisters/group/impl/src/lib.rs +++ b/backend/canisters/group/impl/src/lib.rs @@ -24,7 +24,7 @@ use msgpack::serialize_then_unwrap; use notifications_canister::c2c_push_notification; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; -use stable_memory_map::{ChatEventKeyPrefix, KeyPrefix}; +use stable_memory_map::{BaseKeyPrefix, ChatEventKeyPrefix}; use std::cell::RefCell; use std::collections::hash_map::Entry::{Occupied, Vacant}; use std::collections::{BTreeMap, HashMap, HashSet}; @@ -359,7 +359,7 @@ impl RuntimeState { for thread in result.threads { self.data .stable_memory_keys_to_garbage_collect - .push(KeyPrefix::from(ChatEventKeyPrefix::new_from_group_chat(Some( + .push(BaseKeyPrefix::from(ChatEventKeyPrefix::new_from_group_chat(Some( thread.root_message_index, )))); } @@ -470,7 +470,7 @@ struct Data { expiring_member_actions: ExpiringMemberActions, user_cache: UserCache, user_event_sync_queue: GroupedTimerJobQueue, - stable_memory_keys_to_garbage_collect: Vec, + stable_memory_keys_to_garbage_collect: Vec, } fn init_instruction_counts_log() -> InstructionCountsLog { diff --git a/backend/canisters/user/impl/src/lib.rs b/backend/canisters/user/impl/src/lib.rs index 711b7cf5b5..1b5da52747 100644 --- a/backend/canisters/user/impl/src/lib.rs +++ b/backend/canisters/user/impl/src/lib.rs @@ -25,7 +25,7 @@ use model::streak::Streak; use notifications_canister::c2c_push_notification; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; -use stable_memory_map::KeyPrefix; +use stable_memory_map::BaseKeyPrefix; use std::cell::RefCell; use std::collections::{BTreeMap, HashSet}; use std::ops::Deref; @@ -254,7 +254,7 @@ struct Data { pub referred_by: Option, pub referrals: Referrals, pub message_activity_events: MessageActivityEvents, - pub stable_memory_keys_to_garbage_collect: Vec, + pub stable_memory_keys_to_garbage_collect: Vec, } impl Data { diff --git a/backend/canisters/user/impl/src/updates/delete_direct_chat.rs b/backend/canisters/user/impl/src/updates/delete_direct_chat.rs index 28095fc371..4ce1f5913d 100644 --- a/backend/canisters/user/impl/src/updates/delete_direct_chat.rs +++ b/backend/canisters/user/impl/src/updates/delete_direct_chat.rs @@ -2,7 +2,7 @@ use crate::guards::caller_is_owner; use crate::{mutate_state, run_regular_jobs, RuntimeState}; use canister_api_macros::update; use canister_tracing_macros::trace; -use stable_memory_map::{ChatEventKeyPrefix, KeyPrefix}; +use stable_memory_map::{BaseKeyPrefix, ChatEventKeyPrefix}; use user_canister::delete_direct_chat::{Response::*, *}; #[update(guard = "caller_is_owner", msgpack = true)] @@ -23,16 +23,15 @@ fn delete_direct_chat_impl(args: Args, state: &mut RuntimeState) -> Response { state .data .stable_memory_keys_to_garbage_collect - .push(KeyPrefix::from(ChatEventKeyPrefix::new_from_direct_chat(args.user_id, None))); + .push(BaseKeyPrefix::from(ChatEventKeyPrefix::new_from_direct_chat( + args.user_id, + None, + ))); for message_index in chat.events.thread_keys() { - state - .data - .stable_memory_keys_to_garbage_collect - .push(KeyPrefix::from(ChatEventKeyPrefix::new_from_direct_chat( - args.user_id, - Some(message_index), - ))); + state.data.stable_memory_keys_to_garbage_collect.push(BaseKeyPrefix::from( + ChatEventKeyPrefix::new_from_direct_chat(args.user_id, Some(message_index)), + )); } crate::jobs::garbage_collect_stable_memory::start_job_if_required(state); diff --git a/backend/libraries/chat_events/src/stable_memory/mod.rs b/backend/libraries/chat_events/src/stable_memory/mod.rs index 6c2fca19e3..e47f52d2e9 100644 --- a/backend/libraries/chat_events/src/stable_memory/mod.rs +++ b/backend/libraries/chat_events/src/stable_memory/mod.rs @@ -1,7 +1,7 @@ use crate::{ChatEventInternal, EventsMap}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; -use stable_memory_map::{with_map, with_map_mut, ChatEventKey, ChatEventKeyPrefix, Key}; +use stable_memory_map::{with_map, with_map_mut, ChatEventKey, ChatEventKeyPrefix, KeyPrefix}; use std::cmp::min; use std::collections::VecDeque; use std::ops::RangeBounds; @@ -15,16 +15,15 @@ mod tests; // Used to efficiently read all events from stable memory when migrating a group into a community pub fn read_events_as_bytes(chat: Chat, after: Option, max_bytes: usize) -> Vec<(EventContext, ByteBuf)> { let key = match after { - None => ChatEventKeyPrefix::new_from_chat(chat, None).create_key(EventIndex::default()), + None => ChatEventKeyPrefix::new_from_chat(chat, None).create_key(&EventIndex::default()), Some(EventContext { thread_root_message_index, event_index, - }) => ChatEventKeyPrefix::new_from_chat(chat, thread_root_message_index).create_key(event_index.incr()), + }) => ChatEventKeyPrefix::new_from_chat(chat, thread_root_message_index).create_key(&event_index.incr()), }; with_map(|m| { let mut total_bytes = 0; - m.range(Key::from(key)..) - .map_while(|(k, v)| ChatEventKey::try_from(k).ok().map(|k| (k, v))) + m.range(key..) .take_while(|(k, v)| { if !k.matches_chat(&chat) { return false; @@ -42,11 +41,11 @@ pub fn write_events_as_bytes(chat: Chat, events: Vec<(EventContext, ByteBuf)>) { with_map_mut(|m| { for (context, bytes) in events { let prefix = ChatEventKeyPrefix::new_from_chat(chat, context.thread_root_message_index); - let key = prefix.create_key(context.event_index); + let key = prefix.create_key(&context.event_index); let value = bytes.into_vec(); // Check the event is valid. We could remove this once we're more confident let _ = bytes_to_event(&value); - m.insert(key.into(), value); + m.insert(key, value); } }); } @@ -83,11 +82,6 @@ impl ChatEventsStableStorage { }; Iter::new(prefix, start, end) } - - fn get_internal(&self, event_index: EventIndex) -> Option> { - let key = self.prefix.create_key(event_index); - with_map(|m| m.get(&Key::from(key))) - } } impl EventsMap for ChatEventsStableStorage { @@ -96,15 +90,15 @@ impl EventsMap for ChatEventsStableStorage { } fn get(&self, event_index: EventIndex) -> Option> { - self.get_internal(event_index).map(|v| bytes_to_event(&v)) + with_map(|m| m.get(self.prefix.create_key(&event_index))).map(|v| bytes_to_event(&v)) } fn insert(&mut self, event: EventWrapperInternal) { - with_map_mut(|m| m.insert(Key::from(self.prefix.create_key(event.index)), event_to_bytes(event))); + with_map_mut(|m| m.insert(self.prefix.create_key(&event.index), event_to_bytes(event))); } fn remove(&mut self, event_index: EventIndex) -> Option> { - with_map_mut(|m| m.remove(&Key::from(self.prefix.create_key(event_index)))).map(|v| bytes_to_event(&v)) + with_map_mut(|m| m.remove(self.prefix.create_key(&event_index))).map(|v| bytes_to_event(&v)) } fn range>( @@ -183,11 +177,11 @@ impl Iter { } fn next_key(&self) -> ChatEventKey { - self.prefix.create_key(self.next) + self.prefix.create_key(&self.next) } fn next_back_key(&self) -> ChatEventKey { - self.prefix.create_key(self.next_back) + self.prefix.create_key(&self.next_back) } fn check_buffer_direction(&mut self, forward: bool) { @@ -227,8 +221,8 @@ impl Iterator for Iter { self.check_buffer_direction(true); if self.buffer.is_empty() { self.buffer = with_map(|m| { - m.range(Key::from(self.next_key())..=Key::from(self.next_back_key())) - .map_while(|(k, v)| ChatEventKey::try_from(k).ok().map(|k| (k.event_index(), v))) + m.range(self.next_key()..=self.next_back_key()) + .map(|(k, v)| (k.event_index(), v)) .take(self.next_buffer_size) .collect() }); @@ -252,9 +246,9 @@ impl DoubleEndedIterator for Iter { self.check_buffer_direction(false); if self.buffer.is_empty() { self.buffer = with_map(|m| { - m.range(Key::from(self.next_key())..=Key::from(self.next_back_key())) + m.range(self.next_key()..=self.next_back_key()) .rev() - .map_while(|(k, v)| ChatEventKey::try_from(k).ok().map(|k| (k.event_index(), v))) + .map(|(k, v)| (k.event_index(), v)) .take(self.next_buffer_size) .collect() }); diff --git a/backend/libraries/group_chat_core/src/members/stable_memory.rs b/backend/libraries/group_chat_core/src/members/stable_memory.rs index ad62ec1026..559a4dd651 100644 --- a/backend/libraries/group_chat_core/src/members/stable_memory.rs +++ b/backend/libraries/group_chat_core/src/members/stable_memory.rs @@ -2,7 +2,7 @@ use crate::{GroupMemberInternal, GroupMemberStableStorage}; use candid::{Deserialize, Principal}; use serde::Serialize; use serde_bytes::ByteBuf; -use stable_memory_map::{with_map, with_map_mut, Key, MemberKey, MemberKeyPrefix}; +use stable_memory_map::{with_map, with_map_mut, Key, KeyPrefix, MemberKeyPrefix}; use types::{MultiUserChat, UserId}; #[derive(Serialize, Deserialize)] @@ -21,18 +21,18 @@ impl MembersStableStorage { pub fn get(&self, user_id: &UserId) -> Option { with_map(|m| { - m.get(&self.prefix.create_key(*user_id).into()) + m.get(self.prefix.create_key(user_id)) .map(|v| bytes_to_member(&v).hydrate(*user_id)) }) } pub fn insert(&mut self, member: GroupMemberInternal) { - with_map_mut(|m| m.insert(self.prefix.create_key(member.user_id).into(), member_to_bytes(member.into()))); + with_map_mut(|m| m.insert(self.prefix.create_key(&member.user_id), member_to_bytes(member.into()))); } pub fn remove(&mut self, user_id: &UserId) -> Option { with_map_mut(|m| { - m.remove(&self.prefix.create_key(*user_id).into()) + m.remove(self.prefix.create_key(user_id)) .map(|v| bytes_to_member(&v).hydrate(*user_id)) }) } @@ -44,14 +44,13 @@ impl MembersStableStorage { // Used to efficiently read all members from stable memory when migrating a group into a community pub fn read_members_as_bytes(&self, after: Option, max_bytes: usize) -> Vec<(UserId, ByteBuf)> { let start_key = match after { - None => self.prefix.create_key(Principal::from_slice(&[]).into()), - Some(user_id) => self.prefix.create_key(user_id), + None => self.prefix.create_key(&Principal::from_slice(&[]).into()), + Some(user_id) => self.prefix.create_key(&user_id), }; with_map(|m| { let mut total_bytes = 0; - m.range(Key::from(start_key.clone())..) - .map_while(|(k, v)| MemberKey::try_from(k).ok().map(|k| (k, v))) + m.range(start_key.clone()..) .skip_while(|(k, _)| *k == start_key) .take_while(|(k, v)| { if !k.matches_prefix(&self.prefix) { @@ -68,8 +67,7 @@ impl MembersStableStorage { #[cfg(test)] pub fn all_members(&self) -> Vec { with_map(|m| { - m.range(Key::from(self.prefix.create_key(Principal::from_slice(&[]).into()))..) - .map_while(|(k, v)| MemberKey::try_from(k).ok().map(|k| (k, v))) + m.range(self.prefix.create_key(&Principal::from_slice(&[]).into())..) .take_while(|(k, _)| k.matches_prefix(&self.prefix)) .map(|(k, v)| bytes_to_member(&v).hydrate(k.user_id())) .collect() @@ -87,7 +85,7 @@ pub fn write_members_from_bytes(chat: MultiUserChat, members: Vec<(UserId, ByteB // Check that the bytes are valid let _ = bytes_to_member(&bytes); latest = Some(user_id); - m.insert(prefix.create_key(user_id).into(), bytes); + m.insert(prefix.create_key(&user_id), bytes); } }); latest diff --git a/backend/libraries/stable_memory_map/src/key.rs b/backend/libraries/stable_memory_map/src/key.rs deleted file mode 100644 index f458b1d060..0000000000 --- a/backend/libraries/stable_memory_map/src/key.rs +++ /dev/null @@ -1,555 +0,0 @@ -use ic_principal::Principal; -use ic_stable_structures::storable::Bound; -use ic_stable_structures::Storable; -use serde::{Deserialize, Serialize}; -use std::borrow::Cow; -use types::{ChannelId, Chat, EventIndex, MessageIndex, MultiUserChat, UserId}; - -#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] -#[serde(transparent)] -pub struct Key(#[serde(with = "serde_bytes")] Vec); - -impl Key { - pub fn starts_with(&self, prefix: &[u8]) -> bool { - self.0.starts_with(prefix) - } - - pub fn as_slice(&self) -> &[u8] { - &self.0 - } -} - -impl Storable for Key { - fn to_bytes(&self) -> Cow<[u8]> { - Cow::Borrowed(&self.0) - } - - fn from_bytes(bytes: Cow<[u8]>) -> Self { - Key(bytes.to_vec()) - } - - const BOUND: Bound = Bound::Unbounded; -} - -impl From for Key { - fn from(value: KeyPrefix) -> Self { - Key(value.0) - } -} - -#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] -#[serde(transparent)] -pub struct KeyPrefix(#[serde(with = "serde_bytes")] Vec); - -impl KeyPrefix { - pub fn as_slice(&self) -> &[u8] { - &self.0 - } -} - -macro_rules! key { - ($key_name:ident, $key_prefix_name:ident, $key_types:pat) => { - #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] - #[serde(into = "Key", try_from = "Key")] - pub struct $key_name(Vec); - - #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] - #[serde(into = "KeyPrefix", try_from = "KeyPrefix")] - pub struct $key_prefix_name(Vec); - - impl From<$key_name> for Key { - fn from(value: $key_name) -> Self { - Key(value.0) - } - } - - impl From<$key_prefix_name> for KeyPrefix { - fn from(value: $key_prefix_name) -> Self { - KeyPrefix(value.0) - } - } - - impl TryFrom for $key_name { - type Error = String; - - fn try_from(value: Key) -> Result { - validate_key(&value.0, |kt| matches!(kt, $key_types)).map(|_| $key_name(value.0)) - } - } - - impl TryFrom for $key_prefix_name { - type Error = String; - - fn try_from(value: KeyPrefix) -> Result { - validate_key(&value.0, |kt| matches!(kt, $key_types)).map(|_| $key_prefix_name(value.0)) - } - } - - impl $key_name { - pub fn matches_prefix(&self, prefix: &$key_prefix_name) -> bool { - self.0.starts_with(&prefix.0) - } - } - }; -} - -key!( - ChatEventKey, - ChatEventKeyPrefix, - KeyType::DirectChatEvent - | KeyType::GroupChatEvent - | KeyType::ChannelEvent - | KeyType::DirectChatThreadEvent - | KeyType::GroupChatThreadEvent - | KeyType::ChannelThreadEvent -); - -key!( - MemberKey, - MemberKeyPrefix, - KeyType::GroupMember | KeyType::ChannelMember | KeyType::CommunityMember -); - -key!(CommunityEventKey, CommunityEventKeyPrefix, KeyType::CommunityEvent); - -fn validate_key bool>(key: &[u8], validator: F) -> Result<(), String> { - if extract_key_type(key).is_some_and(validator) { - Ok(()) - } else { - Err(format!("Key type mismatch: {:?}", key.first())) - } -} - -impl ChatEventKeyPrefix { - pub fn new_from_chat(chat: Chat, thread_root_message_index: Option) -> Self { - match chat { - Chat::Direct(user_id) => Self::new_from_direct_chat(Principal::from(user_id).into(), thread_root_message_index), - Chat::Group(_) => Self::new_from_group_chat(thread_root_message_index), - Chat::Channel(_, channel_id) => Self::new_from_channel(channel_id, thread_root_message_index), - } - } - - pub fn new_from_direct_chat(user_id: UserId, thread_root_message_index: Option) -> Self { - // We don't actually need the userId length marker but existing entries have it so we - // need to keep it to be backwards compatible. If we decide we want to remove it then we - // can switch to a new KeyType for direct chat events, but that is quite a lot of work - // and complexity since we'd still have to support the old version too (or migrate them). - - let user_id_bytes = user_id.as_slice(); - - match thread_root_message_index { - None => { - // KeyType::DirectChatThreadEvent 1 byte - // UserId length 1 byte - // UserId bytes UserId length bytes - let mut bytes = Vec::with_capacity(user_id_bytes.len() + 2); - bytes.push(KeyType::DirectChatEvent as u8); - bytes.push(user_id_bytes.len() as u8); - bytes.extend_from_slice(user_id_bytes); - ChatEventKeyPrefix(bytes) - } - Some(root_message_index) => { - // KeyType::DirectChatThreadEvent 1 byte - // UserId length 1 byte - // UserId bytes UserId length bytes - // Thread root message index 4 bytes - let mut bytes = Vec::with_capacity(user_id_bytes.len() + 6); - bytes.push(KeyType::DirectChatThreadEvent as u8); - bytes.push(user_id_bytes.len() as u8); - bytes.extend_from_slice(user_id_bytes); - bytes.extend_from_slice(&u32::from(root_message_index).to_be_bytes()); - ChatEventKeyPrefix(bytes) - } - } - } - - pub fn new_from_group_chat(thread_root_message_index: Option) -> Self { - match thread_root_message_index { - None => { - // KeyType::GroupChatEvent 1 byte - ChatEventKeyPrefix(vec![KeyType::GroupChatEvent as u8]) - } - Some(root_message_index) => { - // KeyType::GroupChatThreadEvent 1 byte - // Thread root message index 4 bytes - let mut bytes = Vec::with_capacity(5); - bytes.push(KeyType::GroupChatThreadEvent as u8); - bytes.extend_from_slice(&u32::from(root_message_index).to_be_bytes()); - ChatEventKeyPrefix(bytes) - } - } - } - - pub fn new_from_channel(channel_id: ChannelId, thread_root_message_index: Option) -> Self { - match thread_root_message_index { - None => { - // KeyType::ChannelEvent 1 byte - // ChannelId 4 bytes - let mut bytes = Vec::with_capacity(5); - bytes.push(KeyType::ChannelEvent as u8); - bytes.extend_from_slice(&channel_id.as_u32().to_be_bytes()); - ChatEventKeyPrefix(bytes) - } - Some(root_message_index) => { - // KeyType::ChannelThreadEvent 1 byte - // ChannelId 4 bytes - // Thread root message index 4 bytes - let mut bytes = Vec::with_capacity(9); - bytes.push(KeyType::ChannelThreadEvent as u8); - bytes.extend_from_slice(&channel_id.as_u32().to_be_bytes()); - bytes.extend_from_slice(&u32::from(root_message_index).to_be_bytes()); - ChatEventKeyPrefix(bytes) - } - } - } - - pub fn create_key(&self, event_index: EventIndex) -> ChatEventKey { - let mut bytes = Vec::with_capacity(self.0.len() + 4); - bytes.extend_from_slice(self.0.as_slice()); - bytes.extend_from_slice(&u32::from(event_index).to_be_bytes()); - ChatEventKey(bytes) - } -} - -impl ChatEventKey { - pub fn matches_chat(&self, chat: &Chat) -> bool { - match (chat, self.key_type()) { - (Chat::Direct(id), KeyType::DirectChatEvent | KeyType::DirectChatThreadEvent) => { - let user_id_len = self.0[1] as usize; - let user_id_start = 2; - let user_id_end = user_id_start + user_id_len; - let user_id = Principal::from_slice(&self.0[user_id_start..user_id_end]).into(); - *id == user_id - } - (Chat::Group(_), KeyType::GroupChatEvent | KeyType::GroupChatThreadEvent) => true, - (Chat::Channel(_, id), KeyType::ChannelEvent | KeyType::ChannelThreadEvent) => { - let channel_id = u32::from_be_bytes(self.0[1..5].try_into().unwrap()).into(); - *id == channel_id - } - _ => false, - } - } - - pub fn thread_root_message_index(&self) -> Option { - if matches!( - self.key_type(), - KeyType::DirectChatThreadEvent | KeyType::GroupChatThreadEvent | KeyType::ChannelThreadEvent - ) { - let start = self.0.len() - 8; - let end = start + 4; - Some(u32::from_be_bytes(self.0[start..end].try_into().unwrap()).into()) - } else { - None - } - } - - pub fn event_index(&self) -> EventIndex { - let start = self.0.len() - 4; - u32::from_be_bytes(self.0[start..].try_into().unwrap()).into() - } - - fn key_type(&self) -> KeyType { - extract_key_type(&self.0).unwrap() - } -} - -impl MemberKeyPrefix { - pub fn new_from_chat(chat: MultiUserChat) -> Self { - match chat { - MultiUserChat::Group(_) => Self::new_from_group(), - MultiUserChat::Channel(_, channel_id) => Self::new_from_channel(channel_id), - } - } - - pub fn new_from_group() -> Self { - // KeyType::GroupMember 1 byte - MemberKeyPrefix(vec![KeyType::GroupMember as u8]) - } - - pub fn new_from_channel(channel_id: ChannelId) -> Self { - // KeyType::ChannelMember 1 byte - // ChannelId 4 bytes - let mut bytes = Vec::with_capacity(5); - bytes.push(KeyType::ChannelMember as u8); - bytes.extend_from_slice(&channel_id.as_u32().to_be_bytes()); - MemberKeyPrefix(bytes) - } - - pub fn new_from_community() -> Self { - // KeyType::CommunityMember 1 byte - MemberKeyPrefix(vec![KeyType::CommunityMember as u8]) - } - - pub fn create_key(&self, user_id: UserId) -> MemberKey { - let user_id_bytes = user_id.as_slice(); - let mut bytes = Vec::with_capacity(self.0.len() + user_id_bytes.len()); - bytes.extend_from_slice(self.0.as_slice()); - bytes.extend_from_slice(user_id_bytes); - MemberKey(bytes) - } -} - -impl MemberKey { - pub fn user_id(&self) -> UserId { - let prefix_len = match self.key_type() { - KeyType::GroupMember | KeyType::CommunityMember => 1, - KeyType::ChannelMember => 5, - _ => unreachable!(), - }; - Principal::from_slice(&self.0[prefix_len..]).into() - } - - fn key_type(&self) -> KeyType { - KeyType::try_from(self.0[0]).unwrap() - } -} - -impl CommunityEventKeyPrefix { - pub fn new() -> Self { - // KeyType::CommunityEvent 1 byte - CommunityEventKeyPrefix(vec![KeyType::CommunityEvent as u8]) - } - - pub fn create_key(&self, event_index: EventIndex) -> CommunityEventKey { - let mut bytes = Vec::with_capacity(5); - bytes.extend_from_slice(self.0.as_slice()); - bytes.extend_from_slice(&u32::from(event_index).to_be_bytes()); - CommunityEventKey(bytes) - } -} - -impl Default for CommunityEventKeyPrefix { - fn default() -> Self { - Self::new() - } -} - -impl CommunityEventKey { - pub fn event_index(&self) -> EventIndex { - let start = self.0.len() - 4; - u32::from_be_bytes(self.0[start..].try_into().unwrap()).into() - } -} - -#[derive(Copy, Clone, Eq, PartialEq)] -#[repr(u8)] -pub enum KeyType { - DirectChatEvent = 1, - GroupChatEvent = 2, - ChannelEvent = 3, - DirectChatThreadEvent = 4, - GroupChatThreadEvent = 5, - ChannelThreadEvent = 6, - GroupMember = 7, - ChannelMember = 8, - CommunityMember = 9, - CommunityEvent = 10, -} - -fn extract_key_type(bytes: &[u8]) -> Option { - bytes.first().and_then(|b| KeyType::try_from(*b).ok()) -} - -impl TryFrom for KeyType { - type Error = (); - - fn try_from(value: u8) -> Result { - match value { - 1 => Ok(KeyType::DirectChatEvent), - 2 => Ok(KeyType::GroupChatEvent), - 3 => Ok(KeyType::ChannelEvent), - 4 => Ok(KeyType::DirectChatThreadEvent), - 5 => Ok(KeyType::GroupChatThreadEvent), - 6 => Ok(KeyType::ChannelThreadEvent), - 7 => Ok(KeyType::GroupMember), - 8 => Ok(KeyType::ChannelMember), - 9 => Ok(KeyType::CommunityMember), - 10 => Ok(KeyType::CommunityEvent), - _ => Err(()), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ChatEventKey, ChatEventKeyPrefix}; - use rand::{thread_rng, Rng, RngCore}; - use types::{ChannelId, Chat, EventIndex, MessageIndex}; - - #[test] - fn direct_chat_event_key_e2e() { - for thread in [false, true] { - for _ in 0..100 { - let user_id_bytes: [u8; 10] = thread_rng().gen(); - let user_id = Principal::from_slice(&user_id_bytes); - let thread_root_message_index = thread.then(|| MessageIndex::from(thread_rng().next_u32())); - let prefix = ChatEventKeyPrefix::new_from_direct_chat(user_id.into(), thread_root_message_index); - let event_index = EventIndex::from(thread_rng().next_u32()); - let key = Key::from(prefix.create_key(event_index)); - let event_key = ChatEventKey::try_from(key.clone()).unwrap(); - - assert_eq!( - *event_key.0.first().unwrap(), - if thread { KeyType::DirectChatThreadEvent } else { KeyType::DirectChatEvent } as u8 - ); - assert_eq!(event_key.0.len(), if thread { 20 } else { 16 }); - assert!(event_key.matches_prefix(&prefix)); - assert!(event_key.matches_chat(&Chat::Direct(user_id.into()))); - assert_eq!(event_key.event_index(), event_index); - - let serialized = msgpack::serialize_then_unwrap(&event_key); - assert_eq!(serialized.len(), event_key.0.len() + 2); - let deserialized: ChatEventKey = msgpack::deserialize_then_unwrap(&serialized); - assert_eq!(deserialized, event_key); - assert_eq!(deserialized.0, key.0); - } - } - } - - #[test] - fn group_chat_event_key_e2e() { - for thread in [false, true] { - for _ in 0..100 { - let thread_root_message_index = thread.then(|| MessageIndex::from(thread_rng().next_u32())); - let prefix = ChatEventKeyPrefix::new_from_group_chat(thread_root_message_index); - let event_index = EventIndex::from(thread_rng().next_u32()); - let key = Key::from(prefix.create_key(event_index)); - let event_key = ChatEventKey::try_from(key.clone()).unwrap(); - - assert_eq!( - *event_key.0.first().unwrap(), - if thread { KeyType::GroupChatThreadEvent } else { KeyType::GroupChatEvent } as u8 - ); - assert_eq!(event_key.0.len(), if thread { 9 } else { 5 }); - assert!(event_key.matches_prefix(&prefix)); - assert!(event_key.matches_chat(&Chat::Group(Principal::anonymous().into()))); - assert_eq!(event_key.event_index(), event_index); - assert_eq!(event_key.thread_root_message_index(), thread_root_message_index); - - let serialized = msgpack::serialize_then_unwrap(&event_key); - assert_eq!(serialized.len(), event_key.0.len() + 2); - let deserialized: ChatEventKey = msgpack::deserialize_then_unwrap(&serialized); - assert_eq!(deserialized, event_key); - assert_eq!(deserialized.0, key.0); - } - } - } - - #[test] - fn channel_event_key_e2e() { - for thread in [false, true] { - for _ in 0..100 { - let channel_id = ChannelId::from(thread_rng().next_u32()); - let thread_root_message_index = thread.then(|| MessageIndex::from(thread_rng().next_u32())); - let prefix = ChatEventKeyPrefix::new_from_channel(channel_id, thread_root_message_index); - let event_index = EventIndex::from(thread_rng().next_u32()); - let key = Key::from(prefix.create_key(event_index)); - let event_key = ChatEventKey::try_from(key.clone()).unwrap(); - - assert_eq!( - *event_key.0.first().unwrap(), - if thread { KeyType::ChannelThreadEvent } else { KeyType::ChannelEvent } as u8 - ); - assert_eq!(event_key.0.len(), if thread { 13 } else { 9 }); - assert!(event_key.matches_prefix(&prefix)); - assert!(event_key.matches_chat(&Chat::Channel(Principal::anonymous().into(), channel_id))); - assert_eq!(event_key.event_index(), event_index); - - let serialized = msgpack::serialize_then_unwrap(&event_key); - assert_eq!(serialized.len(), event_key.0.len() + 2); - let deserialized: ChatEventKey = msgpack::deserialize_then_unwrap(&serialized); - assert_eq!(deserialized, event_key); - assert_eq!(deserialized.0, key.0); - } - } - } - - #[test] - fn group_chat_member_key_e2e() { - for _ in 0..100 { - let user_id_bytes: [u8; 10] = thread_rng().gen(); - let user_id = UserId::from(Principal::from_slice(&user_id_bytes)); - let prefix = MemberKeyPrefix::new_from_group(); - let key = Key::from(prefix.create_key(user_id)); - let member_key = MemberKey::try_from(key.clone()).unwrap(); - - assert_eq!(*member_key.0.first().unwrap(), KeyType::GroupMember as u8); - assert_eq!(member_key.0.len(), 11); - assert!(member_key.matches_prefix(&prefix)); - assert_eq!(member_key.user_id(), user_id); - - let serialized = msgpack::serialize_then_unwrap(&member_key); - assert_eq!(serialized.len(), member_key.0.len() + 2); - let deserialized: MemberKey = msgpack::deserialize_then_unwrap(&serialized); - assert_eq!(deserialized, member_key); - assert_eq!(deserialized.0, key.0); - } - } - - #[test] - fn channel_member_key_e2e() { - for _ in 0..100 { - let channel_id = ChannelId::from(thread_rng().next_u32()); - let user_id_bytes: [u8; 10] = thread_rng().gen(); - let user_id = UserId::from(Principal::from_slice(&user_id_bytes)); - let prefix = MemberKeyPrefix::new_from_channel(channel_id); - let key = Key::from(prefix.create_key(user_id)); - let member_key = MemberKey::try_from(key.clone()).unwrap(); - - assert_eq!(*member_key.0.first().unwrap(), KeyType::ChannelMember as u8); - assert_eq!(member_key.0.len(), 15); - assert!(member_key.matches_prefix(&prefix)); - assert_eq!(member_key.user_id(), user_id); - - let serialized = msgpack::serialize_then_unwrap(&member_key); - assert_eq!(serialized.len(), member_key.0.len() + 2); - let deserialized: MemberKey = msgpack::deserialize_then_unwrap(&serialized); - assert_eq!(deserialized, member_key); - assert_eq!(deserialized.0, key.0); - } - } - - #[test] - fn community_member_key_e2e() { - for _ in 0..100 { - let user_id_bytes: [u8; 10] = thread_rng().gen(); - let user_id = UserId::from(Principal::from_slice(&user_id_bytes)); - let prefix = MemberKeyPrefix::new_from_community(); - let key = Key::from(prefix.create_key(user_id)); - let member_key = MemberKey::try_from(key.clone()).unwrap(); - - assert_eq!(*member_key.0.first().unwrap(), KeyType::CommunityMember as u8); - assert_eq!(member_key.0.len(), 11); - assert!(member_key.matches_prefix(&prefix)); - assert_eq!(member_key.user_id(), user_id); - - let serialized = msgpack::serialize_then_unwrap(&member_key); - assert_eq!(serialized.len(), member_key.0.len() + 2); - let deserialized: MemberKey = msgpack::deserialize_then_unwrap(&serialized); - assert_eq!(deserialized, member_key); - assert_eq!(deserialized.0, key.0); - } - } - - #[test] - fn community_event_key_e2e() { - for _ in 0..100 { - let prefix = CommunityEventKeyPrefix::new(); - let event_index = EventIndex::from(thread_rng().next_u32()); - let key = Key::from(prefix.create_key(event_index)); - let event_key = CommunityEventKey::try_from(key.clone()).unwrap(); - - assert_eq!(*event_key.0.first().unwrap(), KeyType::CommunityEvent as u8); - assert_eq!(event_key.0.len(), 5); - assert!(event_key.matches_prefix(&prefix)); - assert_eq!(event_key.event_index(), event_index); - - let serialized = msgpack::serialize_then_unwrap(&event_key); - assert_eq!(serialized.len(), event_key.0.len() + 2); - let deserialized: CommunityEventKey = msgpack::deserialize_then_unwrap(&serialized); - assert_eq!(deserialized, event_key); - assert_eq!(deserialized.0, key.0); - } - } -} diff --git a/backend/libraries/stable_memory_map/src/keys.rs b/backend/libraries/stable_memory_map/src/keys.rs new file mode 100644 index 0000000000..e976f0e5ed --- /dev/null +++ b/backend/libraries/stable_memory_map/src/keys.rs @@ -0,0 +1,115 @@ +use ic_stable_structures::storable::Bound; +use ic_stable_structures::Storable; +use serde::{Deserialize, Serialize}; +use std::borrow::Cow; + +mod chat_event; +mod community_event; +mod macros; +mod member; + +pub use chat_event::*; +pub use community_event::*; +pub use member::*; + +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] +#[serde(transparent)] +pub struct BaseKey(#[serde(with = "serde_bytes")] Vec); + +impl BaseKey { + pub fn starts_with(&self, prefix: &BaseKeyPrefix) -> bool { + self.0.starts_with(prefix.0.as_slice()) + } + + pub fn as_slice(&self) -> &[u8] { + self.0.as_slice() + } +} + +impl Storable for BaseKey { + fn to_bytes(&self) -> Cow<[u8]> { + Cow::Borrowed(&self.0) + } + + fn from_bytes(bytes: Cow<[u8]>) -> Self { + BaseKey(bytes.to_vec()) + } + + const BOUND: Bound = Bound::Unbounded; +} + +impl From for BaseKey { + fn from(value: BaseKeyPrefix) -> Self { + BaseKey(value.0) + } +} + +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] +#[serde(transparent)] +pub struct BaseKeyPrefix(#[serde(with = "serde_bytes")] Vec); + +impl BaseKeyPrefix { + pub fn as_slice(&self) -> &[u8] { + self.0.as_slice() + } +} + +pub trait Key: Into + TryFrom + Clone { + type Prefix: KeyPrefix; + + fn matches_prefix(&self, key: &Self::Prefix) -> bool; +} + +pub trait KeyPrefix: Into + TryFrom + Clone { + type Key; + type Suffix; + + fn create_key(&self, value: &Self::Suffix) -> Self::Key; +} + +fn validate_key bool>(key: &[u8], validator: F) -> Result<(), String> { + if extract_key_type(key).is_some_and(validator) { + Ok(()) + } else { + Err(format!("Key type mismatch: {:?}", key.first())) + } +} + +#[derive(Copy, Clone, Eq, PartialEq)] +#[repr(u8)] +pub enum KeyType { + DirectChatEvent = 1, + GroupChatEvent = 2, + ChannelEvent = 3, + DirectChatThreadEvent = 4, + GroupChatThreadEvent = 5, + ChannelThreadEvent = 6, + GroupMember = 7, + ChannelMember = 8, + CommunityMember = 9, + CommunityEvent = 10, +} + +fn extract_key_type(bytes: &[u8]) -> Option { + bytes.first().and_then(|b| KeyType::try_from(*b).ok()) +} + +impl TryFrom for KeyType { + type Error = (); + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(KeyType::DirectChatEvent), + 2 => Ok(KeyType::GroupChatEvent), + 3 => Ok(KeyType::ChannelEvent), + 4 => Ok(KeyType::DirectChatThreadEvent), + 5 => Ok(KeyType::GroupChatThreadEvent), + 6 => Ok(KeyType::ChannelThreadEvent), + 7 => Ok(KeyType::GroupMember), + 8 => Ok(KeyType::ChannelMember), + 9 => Ok(KeyType::CommunityMember), + 10 => Ok(KeyType::CommunityEvent), + _ => Err(()), + } + } +} diff --git a/backend/libraries/stable_memory_map/src/keys/chat_event.rs b/backend/libraries/stable_memory_map/src/keys/chat_event.rs new file mode 100644 index 0000000000..6acca7947b --- /dev/null +++ b/backend/libraries/stable_memory_map/src/keys/chat_event.rs @@ -0,0 +1,250 @@ +use crate::keys::extract_key_type; +use crate::keys::macros::key; +use crate::{BaseKey, KeyPrefix, KeyType}; +use ic_principal::Principal; +use types::{ChannelId, Chat, EventIndex, MessageIndex, UserId}; + +key!( + ChatEventKey, + ChatEventKeyPrefix, + KeyType::DirectChatEvent + | KeyType::GroupChatEvent + | KeyType::ChannelEvent + | KeyType::DirectChatThreadEvent + | KeyType::GroupChatThreadEvent + | KeyType::ChannelThreadEvent +); + +impl ChatEventKeyPrefix { + pub fn new_from_chat(chat: Chat, thread_root_message_index: Option) -> Self { + match chat { + Chat::Direct(user_id) => Self::new_from_direct_chat(Principal::from(user_id).into(), thread_root_message_index), + Chat::Group(_) => Self::new_from_group_chat(thread_root_message_index), + Chat::Channel(_, channel_id) => Self::new_from_channel(channel_id, thread_root_message_index), + } + } + + pub fn new_from_direct_chat(user_id: UserId, thread_root_message_index: Option) -> Self { + // We don't actually need the userId length marker but existing entries have it so we + // need to keep it to be backwards compatible. If we decide we want to remove it then we + // can switch to a new KeyType for direct chat events, but that is quite a lot of work + // and complexity since we'd still have to support the old version too (or migrate them). + + let user_id_bytes = user_id.as_slice(); + + match thread_root_message_index { + None => { + // KeyType::DirectChatThreadEvent 1 byte + // UserId length 1 byte + // UserId bytes UserId length bytes + let mut bytes = Vec::with_capacity(user_id_bytes.len() + 2); + bytes.push(KeyType::DirectChatEvent as u8); + bytes.push(user_id_bytes.len() as u8); + bytes.extend_from_slice(user_id_bytes); + ChatEventKeyPrefix(bytes) + } + Some(root_message_index) => { + // KeyType::DirectChatThreadEvent 1 byte + // UserId length 1 byte + // UserId bytes UserId length bytes + // Thread root message index 4 bytes + let mut bytes = Vec::with_capacity(user_id_bytes.len() + 6); + bytes.push(KeyType::DirectChatThreadEvent as u8); + bytes.push(user_id_bytes.len() as u8); + bytes.extend_from_slice(user_id_bytes); + bytes.extend_from_slice(&u32::from(root_message_index).to_be_bytes()); + ChatEventKeyPrefix(bytes) + } + } + } + + pub fn new_from_group_chat(thread_root_message_index: Option) -> Self { + match thread_root_message_index { + None => { + // KeyType::GroupChatEvent 1 byte + ChatEventKeyPrefix(vec![KeyType::GroupChatEvent as u8]) + } + Some(root_message_index) => { + // KeyType::GroupChatThreadEvent 1 byte + // Thread root message index 4 bytes + let mut bytes = Vec::with_capacity(5); + bytes.push(KeyType::GroupChatThreadEvent as u8); + bytes.extend_from_slice(&u32::from(root_message_index).to_be_bytes()); + ChatEventKeyPrefix(bytes) + } + } + } + + pub fn new_from_channel(channel_id: ChannelId, thread_root_message_index: Option) -> Self { + match thread_root_message_index { + None => { + // KeyType::ChannelEvent 1 byte + // ChannelId 4 bytes + let mut bytes = Vec::with_capacity(5); + bytes.push(KeyType::ChannelEvent as u8); + bytes.extend_from_slice(&channel_id.as_u32().to_be_bytes()); + ChatEventKeyPrefix(bytes) + } + Some(root_message_index) => { + // KeyType::ChannelThreadEvent 1 byte + // ChannelId 4 bytes + // Thread root message index 4 bytes + let mut bytes = Vec::with_capacity(9); + bytes.push(KeyType::ChannelThreadEvent as u8); + bytes.extend_from_slice(&channel_id.as_u32().to_be_bytes()); + bytes.extend_from_slice(&u32::from(root_message_index).to_be_bytes()); + ChatEventKeyPrefix(bytes) + } + } + } +} + +impl KeyPrefix for ChatEventKeyPrefix { + type Key = ChatEventKey; + type Suffix = EventIndex; + + fn create_key(&self, event_index: &EventIndex) -> ChatEventKey { + let mut bytes = Vec::with_capacity(self.0.len() + 4); + bytes.extend_from_slice(self.0.as_slice()); + bytes.extend_from_slice(&u32::from(*event_index).to_be_bytes()); + ChatEventKey(bytes) + } +} + +impl ChatEventKey { + pub fn matches_chat(&self, chat: &Chat) -> bool { + match (chat, self.key_type()) { + (Chat::Direct(id), KeyType::DirectChatEvent | KeyType::DirectChatThreadEvent) => { + let user_id_len = self.0[1] as usize; + let user_id_start = 2; + let user_id_end = user_id_start + user_id_len; + let user_id = Principal::from_slice(&self.0[user_id_start..user_id_end]).into(); + *id == user_id + } + (Chat::Group(_), KeyType::GroupChatEvent | KeyType::GroupChatThreadEvent) => true, + (Chat::Channel(_, id), KeyType::ChannelEvent | KeyType::ChannelThreadEvent) => { + let channel_id = u32::from_be_bytes(self.0[1..5].try_into().unwrap()).into(); + *id == channel_id + } + _ => false, + } + } + + pub fn thread_root_message_index(&self) -> Option { + if matches!( + self.key_type(), + KeyType::DirectChatThreadEvent | KeyType::GroupChatThreadEvent | KeyType::ChannelThreadEvent + ) { + let start = self.0.len() - 8; + let end = start + 4; + Some(u32::from_be_bytes(self.0[start..end].try_into().unwrap()).into()) + } else { + None + } + } + + pub fn event_index(&self) -> EventIndex { + let start = self.0.len() - 4; + u32::from_be_bytes(self.0[start..].try_into().unwrap()).into() + } + + fn key_type(&self) -> KeyType { + extract_key_type(&self.0).unwrap() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Key; + use rand::{thread_rng, Rng, RngCore}; + use types::{ChannelId, Chat, EventIndex, MessageIndex}; + + #[test] + fn direct_chat_event_key_e2e() { + for thread in [false, true] { + for _ in 0..100 { + let user_id_bytes: [u8; 10] = thread_rng().gen(); + let user_id = Principal::from_slice(&user_id_bytes); + let thread_root_message_index = thread.then(|| MessageIndex::from(thread_rng().next_u32())); + let prefix = ChatEventKeyPrefix::new_from_direct_chat(user_id.into(), thread_root_message_index); + let event_index = EventIndex::from(thread_rng().next_u32()); + let key = BaseKey::from(prefix.create_key(&event_index)); + let event_key = ChatEventKey::try_from(key.clone()).unwrap(); + + assert_eq!( + *event_key.0.first().unwrap(), + if thread { KeyType::DirectChatThreadEvent } else { KeyType::DirectChatEvent } as u8 + ); + assert_eq!(event_key.0.len(), if thread { 20 } else { 16 }); + assert!(event_key.matches_prefix(&prefix)); + assert!(event_key.matches_chat(&Chat::Direct(user_id.into()))); + assert_eq!(event_key.event_index(), event_index); + + let serialized = msgpack::serialize_then_unwrap(&event_key); + assert_eq!(serialized.len(), event_key.0.len() + 2); + let deserialized: ChatEventKey = msgpack::deserialize_then_unwrap(&serialized); + assert_eq!(deserialized, event_key); + assert_eq!(deserialized.0, key.0); + } + } + } + + #[test] + fn group_chat_event_key_e2e() { + for thread in [false, true] { + for _ in 0..100 { + let thread_root_message_index = thread.then(|| MessageIndex::from(thread_rng().next_u32())); + let prefix = ChatEventKeyPrefix::new_from_group_chat(thread_root_message_index); + let event_index = EventIndex::from(thread_rng().next_u32()); + let key = BaseKey::from(prefix.create_key(&event_index)); + let event_key = ChatEventKey::try_from(key.clone()).unwrap(); + + assert_eq!( + *event_key.0.first().unwrap(), + if thread { KeyType::GroupChatThreadEvent } else { KeyType::GroupChatEvent } as u8 + ); + assert_eq!(event_key.0.len(), if thread { 9 } else { 5 }); + assert!(event_key.matches_prefix(&prefix)); + assert!(event_key.matches_chat(&Chat::Group(Principal::anonymous().into()))); + assert_eq!(event_key.event_index(), event_index); + assert_eq!(event_key.thread_root_message_index(), thread_root_message_index); + + let serialized = msgpack::serialize_then_unwrap(&event_key); + assert_eq!(serialized.len(), event_key.0.len() + 2); + let deserialized: ChatEventKey = msgpack::deserialize_then_unwrap(&serialized); + assert_eq!(deserialized, event_key); + assert_eq!(deserialized.0, key.0); + } + } + } + + #[test] + fn channel_event_key_e2e() { + for thread in [false, true] { + for _ in 0..100 { + let channel_id = ChannelId::from(thread_rng().next_u32()); + let thread_root_message_index = thread.then(|| MessageIndex::from(thread_rng().next_u32())); + let prefix = ChatEventKeyPrefix::new_from_channel(channel_id, thread_root_message_index); + let event_index = EventIndex::from(thread_rng().next_u32()); + let key = BaseKey::from(prefix.create_key(&event_index)); + let event_key = ChatEventKey::try_from(key.clone()).unwrap(); + + assert_eq!( + *event_key.0.first().unwrap(), + if thread { KeyType::ChannelThreadEvent } else { KeyType::ChannelEvent } as u8 + ); + assert_eq!(event_key.0.len(), if thread { 13 } else { 9 }); + assert!(event_key.matches_prefix(&prefix)); + assert!(event_key.matches_chat(&Chat::Channel(Principal::anonymous().into(), channel_id))); + assert_eq!(event_key.event_index(), event_index); + + let serialized = msgpack::serialize_then_unwrap(&event_key); + assert_eq!(serialized.len(), event_key.0.len() + 2); + let deserialized: ChatEventKey = msgpack::deserialize_then_unwrap(&serialized); + assert_eq!(deserialized, event_key); + assert_eq!(deserialized.0, key.0); + } + } + } +} diff --git a/backend/libraries/stable_memory_map/src/keys/community_event.rs b/backend/libraries/stable_memory_map/src/keys/community_event.rs new file mode 100644 index 0000000000..239d6349d6 --- /dev/null +++ b/backend/libraries/stable_memory_map/src/keys/community_event.rs @@ -0,0 +1,66 @@ +use crate::keys::macros::key; +use crate::{BaseKey, KeyPrefix, KeyType}; +use types::EventIndex; + +key!(CommunityEventKey, CommunityEventKeyPrefix, KeyType::CommunityEvent); + +impl CommunityEventKeyPrefix { + pub fn new() -> Self { + // KeyType::CommunityEvent 1 byte + CommunityEventKeyPrefix(vec![KeyType::CommunityEvent as u8]) + } +} + +impl KeyPrefix for CommunityEventKeyPrefix { + type Key = CommunityEventKey; + type Suffix = EventIndex; + + fn create_key(&self, event_index: &EventIndex) -> CommunityEventKey { + let mut bytes = Vec::with_capacity(5); + bytes.extend_from_slice(self.0.as_slice()); + bytes.extend_from_slice(&u32::from(*event_index).to_be_bytes()); + CommunityEventKey(bytes) + } +} + +impl Default for CommunityEventKeyPrefix { + fn default() -> Self { + Self::new() + } +} + +impl CommunityEventKey { + pub fn event_index(&self) -> EventIndex { + let start = self.0.len() - 4; + u32::from_be_bytes(self.0[start..].try_into().unwrap()).into() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Key; + use rand::{thread_rng, RngCore}; + use types::EventIndex; + + #[test] + fn community_event_key_e2e() { + for _ in 0..100 { + let prefix = CommunityEventKeyPrefix::new(); + let event_index = EventIndex::from(thread_rng().next_u32()); + let key = BaseKey::from(prefix.create_key(&event_index)); + let event_key = CommunityEventKey::try_from(key.clone()).unwrap(); + + assert_eq!(*event_key.0.first().unwrap(), KeyType::CommunityEvent as u8); + assert_eq!(event_key.0.len(), 5); + assert!(event_key.matches_prefix(&prefix)); + assert_eq!(event_key.event_index(), event_index); + + let serialized = msgpack::serialize_then_unwrap(&event_key); + assert_eq!(serialized.len(), event_key.0.len() + 2); + let deserialized: CommunityEventKey = msgpack::deserialize_then_unwrap(&serialized); + assert_eq!(deserialized, event_key); + assert_eq!(deserialized.0, key.0); + } + } +} diff --git a/backend/libraries/stable_memory_map/src/keys/macros.rs b/backend/libraries/stable_memory_map/src/keys/macros.rs new file mode 100644 index 0000000000..93ad47c635 --- /dev/null +++ b/backend/libraries/stable_memory_map/src/keys/macros.rs @@ -0,0 +1,49 @@ +macro_rules! key { + ($key_name:ident, $key_prefix_name:ident, $key_types:pat) => { + #[derive(serde::Serialize, serde::Deserialize, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] + #[serde(into = "crate::keys::BaseKey", try_from = "crate::keys::BaseKey")] + pub struct $key_name(Vec); + + #[derive(serde::Serialize, serde::Deserialize, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] + #[serde(into = "crate::keys::BaseKeyPrefix", try_from = "crate::keys::BaseKeyPrefix")] + pub struct $key_prefix_name(Vec); + + impl From<$key_name> for BaseKey { + fn from(value: $key_name) -> Self { + crate::keys::BaseKey(value.0) + } + } + + impl From<$key_prefix_name> for crate::keys::BaseKeyPrefix { + fn from(value: $key_prefix_name) -> Self { + crate::keys::BaseKeyPrefix(value.0) + } + } + + impl TryFrom for $key_name { + type Error = String; + + fn try_from(value: crate::BaseKey) -> Result { + crate::keys::validate_key(&value.0, |kt| matches!(kt, $key_types)).map(|_| $key_name(value.0)) + } + } + + impl TryFrom for $key_prefix_name { + type Error = String; + + fn try_from(value: crate::BaseKeyPrefix) -> Result { + crate::keys::validate_key(&value.0, |kt| matches!(kt, $key_types)).map(|_| $key_prefix_name(value.0)) + } + } + + impl crate::keys::Key for $key_name { + type Prefix = $key_prefix_name; + + fn matches_prefix(&self, prefix: &$key_prefix_name) -> bool { + self.0.starts_with(&prefix.0) + } + } + }; +} + +pub(crate) use key; diff --git a/backend/libraries/stable_memory_map/src/keys/member.rs b/backend/libraries/stable_memory_map/src/keys/member.rs new file mode 100644 index 0000000000..22d0941605 --- /dev/null +++ b/backend/libraries/stable_memory_map/src/keys/member.rs @@ -0,0 +1,140 @@ +use crate::keys::macros::key; +use crate::{BaseKey, KeyPrefix, KeyType}; +use ic_principal::Principal; +use types::{ChannelId, MultiUserChat, UserId}; + +key!( + MemberKey, + MemberKeyPrefix, + KeyType::GroupMember | KeyType::ChannelMember | KeyType::CommunityMember +); + +impl MemberKeyPrefix { + pub fn new_from_chat(chat: MultiUserChat) -> Self { + match chat { + MultiUserChat::Group(_) => Self::new_from_group(), + MultiUserChat::Channel(_, channel_id) => Self::new_from_channel(channel_id), + } + } + + pub fn new_from_group() -> Self { + // KeyType::GroupMember 1 byte + MemberKeyPrefix(vec![KeyType::GroupMember as u8]) + } + + pub fn new_from_channel(channel_id: ChannelId) -> Self { + // KeyType::ChannelMember 1 byte + // ChannelId 4 bytes + let mut bytes = Vec::with_capacity(5); + bytes.push(KeyType::ChannelMember as u8); + bytes.extend_from_slice(&channel_id.as_u32().to_be_bytes()); + MemberKeyPrefix(bytes) + } + + pub fn new_from_community() -> Self { + // KeyType::CommunityMember 1 byte + MemberKeyPrefix(vec![KeyType::CommunityMember as u8]) + } +} + +impl KeyPrefix for MemberKeyPrefix { + type Key = MemberKey; + type Suffix = UserId; + + fn create_key(&self, user_id: &UserId) -> Self::Key { + let user_id_bytes = user_id.as_slice(); + let mut bytes = Vec::with_capacity(self.0.len() + user_id_bytes.len()); + bytes.extend_from_slice(self.0.as_slice()); + bytes.extend_from_slice(user_id_bytes); + MemberKey(bytes) + } +} + +impl MemberKey { + pub fn user_id(&self) -> UserId { + let prefix_len = match self.key_type() { + KeyType::GroupMember | KeyType::CommunityMember => 1, + KeyType::ChannelMember => 5, + _ => unreachable!(), + }; + Principal::from_slice(&self.0[prefix_len..]).into() + } + + fn key_type(&self) -> KeyType { + KeyType::try_from(self.0[0]).unwrap() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Key; + use rand::{thread_rng, Rng, RngCore}; + + #[test] + fn group_chat_member_key_e2e() { + for _ in 0..100 { + let user_id_bytes: [u8; 10] = thread_rng().gen(); + let user_id = UserId::from(Principal::from_slice(&user_id_bytes)); + let prefix = MemberKeyPrefix::new_from_group(); + let key = BaseKey::from(prefix.create_key(&user_id)); + let member_key = MemberKey::try_from(key.clone()).unwrap(); + + assert_eq!(*member_key.0.first().unwrap(), KeyType::GroupMember as u8); + assert_eq!(member_key.0.len(), 11); + assert!(member_key.matches_prefix(&prefix)); + assert_eq!(member_key.user_id(), user_id); + + let serialized = msgpack::serialize_then_unwrap(&member_key); + assert_eq!(serialized.len(), member_key.0.len() + 2); + let deserialized: MemberKey = msgpack::deserialize_then_unwrap(&serialized); + assert_eq!(deserialized, member_key); + assert_eq!(deserialized.0, key.0); + } + } + + #[test] + fn channel_member_key_e2e() { + for _ in 0..100 { + let channel_id = ChannelId::from(thread_rng().next_u32()); + let user_id_bytes: [u8; 10] = thread_rng().gen(); + let user_id = UserId::from(Principal::from_slice(&user_id_bytes)); + let prefix = MemberKeyPrefix::new_from_channel(channel_id); + let key = BaseKey::from(prefix.create_key(&user_id)); + let member_key = MemberKey::try_from(key.clone()).unwrap(); + + assert_eq!(*member_key.0.first().unwrap(), KeyType::ChannelMember as u8); + assert_eq!(member_key.0.len(), 15); + assert!(member_key.matches_prefix(&prefix)); + assert_eq!(member_key.user_id(), user_id); + + let serialized = msgpack::serialize_then_unwrap(&member_key); + assert_eq!(serialized.len(), member_key.0.len() + 2); + let deserialized: MemberKey = msgpack::deserialize_then_unwrap(&serialized); + assert_eq!(deserialized, member_key); + assert_eq!(deserialized.0, key.0); + } + } + + #[test] + fn community_member_key_e2e() { + for _ in 0..100 { + let user_id_bytes: [u8; 10] = thread_rng().gen(); + let user_id = UserId::from(Principal::from_slice(&user_id_bytes)); + let prefix = MemberKeyPrefix::new_from_community(); + let key = BaseKey::from(prefix.create_key(&user_id)); + let member_key = MemberKey::try_from(key.clone()).unwrap(); + + assert_eq!(*member_key.0.first().unwrap(), KeyType::CommunityMember as u8); + assert_eq!(member_key.0.len(), 11); + assert!(member_key.matches_prefix(&prefix)); + assert_eq!(member_key.user_id(), user_id); + + let serialized = msgpack::serialize_then_unwrap(&member_key); + assert_eq!(serialized.len(), member_key.0.len() + 2); + let deserialized: MemberKey = msgpack::deserialize_then_unwrap(&serialized); + assert_eq!(deserialized, member_key); + assert_eq!(deserialized.0, key.0); + } + } +} diff --git a/backend/libraries/stable_memory_map/src/lib.rs b/backend/libraries/stable_memory_map/src/lib.rs index 904e50c60d..1c83e12626 100644 --- a/backend/libraries/stable_memory_map/src/lib.rs +++ b/backend/libraries/stable_memory_map/src/lib.rs @@ -5,15 +5,17 @@ use ic_stable_structures::memory_manager::VirtualMemory; use ic_stable_structures::{DefaultMemoryImpl, StableBTreeMap}; use std::cell::RefCell; +use std::marker::PhantomData; +use std::ops::{Bound, RangeBounds}; -mod key; +mod keys; -pub use key::*; +pub use keys::*; pub type Memory = VirtualMemory; -struct StableMemoryMap { - map: StableBTreeMap, Memory>, +pub struct StableMemoryMap { + map: StableBTreeMap, Memory>, } thread_local! { @@ -26,24 +28,47 @@ pub fn init(memory: Memory) { })); } -pub fn with_map, Memory>) -> R, R>(f: F) -> R { - MAP.with_borrow(|m| f(&m.as_ref().unwrap().map)) +pub fn with_map R, R>(f: F) -> R { + MAP.with_borrow(|m| f(m.as_ref().unwrap())) } -pub fn with_map_mut, Memory>) -> R, R>(f: F) -> R { - MAP.with_borrow_mut(|m| f(&mut m.as_mut().unwrap().map)) +pub fn with_map_mut R, R>(f: F) -> R { + MAP.with_borrow_mut(|m| f(m.as_mut().unwrap())) } -pub fn garbage_collect(prefix: KeyPrefix) -> Result { - // assert!(!prefix.is_empty()); +impl StableMemoryMap { + pub fn get(&self, key: K) -> Option> { + self.map.get(&key.into()) + } + pub fn insert(&mut self, key: K, value: Vec) { + self.map.insert(key.into(), value); + } + + pub fn remove(&mut self, key: K) -> Option> { + self.map.remove(&key.into()) + } + + pub fn range<'a, K: Key + 'a, R: RangeBounds>(&'a self, range: R) -> impl DoubleEndedIterator)> + 'a { + let start = map_bound(range.start_bound()); + let end = map_bound(range.end_bound()); + + Iter { + inner: self.map.range((start, end)), + _phantom: PhantomData, + } + } +} + +pub fn garbage_collect(prefix: BaseKeyPrefix) -> Result { let mut total_count = 0; with_map_mut(|m| { // If < 2B instructions have been used so far, delete another 100 keys, or exit if complete while ic_cdk::api::instruction_counter() < 2_000_000_000 { let keys: Vec<_> = m - .range(Key::from(prefix.clone())..) - .take_while(|(k, _)| k.starts_with(prefix.as_slice())) + .map + .range(BaseKey::from(prefix.clone())..) + .take_while(|(k, _)| k.starts_with(&prefix)) .map(|(k, _)| k) .take(100) .collect(); @@ -51,7 +76,7 @@ pub fn garbage_collect(prefix: KeyPrefix) -> Result { let batch_count = keys.len() as u32; total_count += batch_count; for key in keys { - m.remove(&key); + m.map.remove(&key); } // If batch count < 100 then we are finished if batch_count < 100 { @@ -61,3 +86,34 @@ pub fn garbage_collect(prefix: KeyPrefix) -> Result { Err(total_count) }) } + +fn map_bound(bound: Bound<&K>) -> Bound { + match bound { + Bound::Included(k) => Bound::Included(k.clone().into()), + Bound::Excluded(k) => Bound::Excluded(k.clone().into()), + Bound::Unbounded => Bound::Unbounded, + } +} + +struct Iter { + inner: I, + _phantom: PhantomData, +} + +impl)>> Iterator for Iter { + type Item = (K, Vec); + + fn next(&mut self) -> Option { + self.inner.next().and_then(try_map_key_value::) + } +} + +impl)>> DoubleEndedIterator for Iter { + fn next_back(&mut self) -> Option { + self.inner.next_back().and_then(try_map_key_value::) + } +} + +fn try_map_key_value((key, value): (BaseKey, Vec)) -> Option<(K, Vec)> { + K::try_from(key).ok().map(|k| (k, value)) +}