diff --git a/Cargo.lock b/Cargo.lock index fe78fc22be..a361a3475a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2907,7 +2907,6 @@ dependencies = [ "lazy_static", "msgpack", "proptest", - "rand 0.8.5", "regex-lite", "search", "serde", @@ -7290,6 +7289,12 @@ version = "0.1.0" dependencies = [ "ic-cdk 0.16.0", "ic-stable-structures", + "ic_principal", + "msgpack", + "rand 0.8.5", + "serde", + "serde_bytes", + "types", ] [[package]] diff --git a/backend/canisters/community/CHANGELOG.md b/backend/canisters/community/CHANGELOG.md index 19a1e61e65..51054d2288 100644 --- a/backend/canisters/community/CHANGELOG.md +++ b/backend/canisters/community/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - Implement `MembersStableStorage` which stores members in stable memory ([#6931](https://github.com/open-chat-labs/open-chat/pull/6931)) - Migrate chat members to stable memory using timer job ([#6933](https://github.com/open-chat-labs/open-chat/pull/6933)) - Write group members to stable memory when importing group into community ([#6935](https://github.com/open-chat-labs/open-chat/pull/6935)) +- Make `StableMemoryMap` use strongly typed keys ([#6937](https://github.com/open-chat-labs/open-chat/pull/6937)) ## [[2.0.1479](https://github.com/open-chat-labs/open-chat/releases/tag/v2.0.1479-community)] - 2024-11-28 diff --git a/backend/canisters/community/impl/src/lib.rs b/backend/canisters/community/impl/src/lib.rs index 696d5154a4..58c2de480e 100644 --- a/backend/canisters/community/impl/src/lib.rs +++ b/backend/canisters/community/impl/src/lib.rs @@ -7,7 +7,7 @@ use activity_notification_state::ActivityNotificationState; use candid::Principal; use canister_state_macros::canister_state; use canister_timer_jobs::TimerJobs; -use chat_events::{ChannelThreadKeyPrefix, ChatMetricsInternal, KeyPrefix}; +use chat_events::ChatMetricsInternal; use community_canister::EventsResponse; use constants::MINUTE_IN_MS; use event_store_producer::{EventStoreClient, EventStoreClientBuilder, EventStoreClientInfo}; @@ -28,6 +28,7 @@ use rand::rngs::StdRng; use rand::RngCore; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; +use stable_memory_map::{ChatEventKeyPrefix, KeyPrefix}; use std::cell::RefCell; use std::ops::Deref; use std::time::Duration; @@ -241,9 +242,12 @@ impl RuntimeState { } final_prize_payments.extend(result.final_prize_payments); for thread in result.threads { - self.data.stable_memory_keys_to_garbage_collect.push( - KeyPrefix::ChannelThread(ChannelThreadKeyPrefix::new(channel.id, thread.root_message_index)).to_vec(), - ); + self.data + .stable_memory_keys_to_garbage_collect + .push(KeyPrefix::from(ChatEventKeyPrefix::new_from_channel( + channel.id, + Some(thread.root_message_index), + ))); } } jobs::garbage_collect_stable_memory::start_job_if_required(self); @@ -366,7 +370,7 @@ struct Data { user_cache: UserCache, user_event_sync_queue: GroupedTimerJobQueue, #[serde(default)] - stable_memory_keys_to_garbage_collect: Vec>, + stable_memory_keys_to_garbage_collect: Vec, #[serde(default)] members_migrated_to_stable_memory: bool, } diff --git a/backend/canisters/community/impl/src/lifecycle/post_upgrade.rs b/backend/canisters/community/impl/src/lifecycle/post_upgrade.rs index 80b698ee4c..9dbb020c75 100644 --- a/backend/canisters/community/impl/src/lifecycle/post_upgrade.rs +++ b/backend/canisters/community/impl/src/lifecycle/post_upgrade.rs @@ -11,7 +11,7 @@ use ic_cdk::post_upgrade; use instruction_counts_log::InstructionCountFunctionId; use stable_memory::get_reader; use tracing::info; -use types::MultiUserChat; +use types::{Chat, MultiUserChat}; #[post_upgrade] #[trace] @@ -26,6 +26,7 @@ fn post_upgrade(args: Args) { let community_id = ic_cdk::id().into(); for channel in data.channels.iter_mut() { + channel.chat.events.set_chat(Chat::Channel(community_id, channel.id)); channel.chat.members.set_member_default_timestamps(); channel .chat diff --git a/backend/canisters/community/impl/src/updates/delete_channel.rs b/backend/canisters/community/impl/src/updates/delete_channel.rs index 6b679872a3..f6616ea07f 100644 --- a/backend/canisters/community/impl/src/updates/delete_channel.rs +++ b/backend/canisters/community/impl/src/updates/delete_channel.rs @@ -4,8 +4,8 @@ use crate::{ }; use canister_api_macros::update; use canister_tracing_macros::trace; -use chat_events::{ChannelKeyPrefix, ChannelThreadKeyPrefix, KeyPrefix}; use community_canister::delete_channel::{Response::*, *}; +use stable_memory_map::{ChatEventKeyPrefix, KeyPrefix, MemberKeyPrefix}; use types::{ChannelDeleted, ChannelId}; #[update(candid = true, msgpack = true)] @@ -54,19 +54,22 @@ fn delete_channel_impl(channel_id: ChannelId, state: &mut RuntimeState) -> Respo state .data .stable_memory_keys_to_garbage_collect - .push(KeyPrefix::Channel(ChannelKeyPrefix::new(channel_id)).to_vec()); + .push(KeyPrefix::from(ChatEventKeyPrefix::new_from_channel(channel_id, None))); for message_index in channel.chat.events.thread_keys() { state .data .stable_memory_keys_to_garbage_collect - .push(KeyPrefix::ChannelThread(ChannelThreadKeyPrefix::new(channel_id, message_index)).to_vec()); + .push(KeyPrefix::from(ChatEventKeyPrefix::new_from_channel( + channel_id, + Some(message_index), + ))); } state .data .stable_memory_keys_to_garbage_collect - .push(group_chat_core::MembersKeyPrefix::Channel(channel_id.as_u32()).to_vec()); + .push(KeyPrefix::from(MemberKeyPrefix::new_from_channel(channel_id))); crate::jobs::garbage_collect_stable_memory::start_job_if_required(state); diff --git a/backend/canisters/group/CHANGELOG.md b/backend/canisters/group/CHANGELOG.md index 07c02ba99e..05a8346143 100644 --- a/backend/canisters/group/CHANGELOG.md +++ b/backend/canisters/group/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - Implement `MembersStableStorage` which stores members in stable memory ([#6931](https://github.com/open-chat-labs/open-chat/pull/6931)) - Migrate chat members to stable memory using timer job ([#6933](https://github.com/open-chat-labs/open-chat/pull/6933)) - Export members from stable memory when importing group into community ([#6935](https://github.com/open-chat-labs/open-chat/pull/6935)) +- Make `StableMemoryMap` use strongly typed keys ([#6937](https://github.com/open-chat-labs/open-chat/pull/6937)) ## [[2.0.1480](https://github.com/open-chat-labs/open-chat/releases/tag/v2.0.1480-group)] - 2024-11-28 diff --git a/backend/canisters/group/impl/src/lib.rs b/backend/canisters/group/impl/src/lib.rs index e91e06f580..9bbe1b8813 100644 --- a/backend/canisters/group/impl/src/lib.rs +++ b/backend/canisters/group/impl/src/lib.rs @@ -7,7 +7,7 @@ use activity_notification_state::ActivityNotificationState; use candid::Principal; use canister_state_macros::canister_state; use canister_timer_jobs::TimerJobs; -use chat_events::{GroupChatThreadKeyPrefix, KeyPrefix, Reader}; +use chat_events::Reader; use constants::{DAY_IN_MS, HOUR_IN_MS, MINUTE_IN_MS, OPENCHAT_BOT_USER_ID}; use event_store_producer::{EventStoreClient, EventStoreClientBuilder, EventStoreClientInfo}; use event_store_producer_cdk_runtime::CdkRuntime; @@ -24,6 +24,7 @@ use msgpack::serialize_then_unwrap; use notifications_canister::c2c_push_notification; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; +use stable_memory_map::{ChatEventKeyPrefix, KeyPrefix}; use std::cell::RefCell; use std::collections::hash_map::Entry::{Occupied, Vacant}; use std::collections::{HashMap, HashSet}; @@ -361,7 +362,9 @@ impl RuntimeState { for thread in result.threads { self.data .stable_memory_keys_to_garbage_collect - .push(KeyPrefix::GroupChatThread(GroupChatThreadKeyPrefix::new(thread.root_message_index)).to_vec()); + .push(KeyPrefix::from(ChatEventKeyPrefix::new_from_group_chat(Some( + thread.root_message_index, + )))); } jobs::garbage_collect_stable_memory::start_job_if_required(self); } @@ -474,7 +477,7 @@ struct Data { user_event_sync_queue: GroupedTimerJobQueue, #[serde(default)] members_migrated_to_stable_memory: bool, - stable_memory_keys_to_garbage_collect: Vec>, + stable_memory_keys_to_garbage_collect: Vec, } fn init_instruction_counts_log() -> InstructionCountsLog { diff --git a/backend/canisters/group/impl/src/lifecycle/post_upgrade.rs b/backend/canisters/group/impl/src/lifecycle/post_upgrade.rs index d87ffddc46..f1f3d572e0 100644 --- a/backend/canisters/group/impl/src/lifecycle/post_upgrade.rs +++ b/backend/canisters/group/impl/src/lifecycle/post_upgrade.rs @@ -10,7 +10,7 @@ use ic_cdk::post_upgrade; use instruction_counts_log::InstructionCountFunctionId; use stable_memory::get_reader; use tracing::info; -use types::MultiUserChat; +use types::{Chat, MultiUserChat}; #[post_upgrade] #[trace] @@ -23,8 +23,10 @@ fn post_upgrade(args: Args) { let (mut data, errors, logs, traces): (Data, Vec, Vec, Vec) = msgpack::deserialize(reader).unwrap(); + let chat_id = ic_cdk::id().into(); + data.chat.events.set_chat(Chat::Group(chat_id)); data.chat.members.set_member_default_timestamps(); - data.chat.members.set_chat(MultiUserChat::Group(ic_cdk::id().into())); + data.chat.members.set_chat(MultiUserChat::Group(chat_id)); canister_logger::init_with_logs(data.test_mode, errors, logs, traces); diff --git a/backend/canisters/user/CHANGELOG.md b/backend/canisters/user/CHANGELOG.md index e2d2736cd7..01dca76b29 100644 --- a/backend/canisters/user/CHANGELOG.md +++ b/backend/canisters/user/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - Extract stable memory map so it can store additional datasets ([#6876](https://github.com/open-chat-labs/open-chat/pull/6876)) - Make `ChannelId` comparisons use their 32bit representation ([#6885](https://github.com/open-chat-labs/open-chat/pull/6885)) - Remove chat event updates after 31 days ([#6916](https://github.com/open-chat-labs/open-chat/pull/6916)) +- Make `StableMemoryMap` use strongly typed keys ([#6937](https://github.com/open-chat-labs/open-chat/pull/6937)) ### Removed diff --git a/backend/canisters/user/impl/src/lib.rs b/backend/canisters/user/impl/src/lib.rs index eb15a9d6db..2db243d273 100644 --- a/backend/canisters/user/impl/src/lib.rs +++ b/backend/canisters/user/impl/src/lib.rs @@ -25,6 +25,7 @@ use model::streak::Streak; use notifications_canister::c2c_push_notification; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; +use stable_memory_map::KeyPrefix; use std::cell::RefCell; use std::collections::HashSet; use std::ops::Deref; @@ -255,7 +256,7 @@ struct Data { pub referred_by: Option, pub referrals: Referrals, pub message_activity_events: MessageActivityEvents, - pub stable_memory_keys_to_garbage_collect: Vec>, + pub stable_memory_keys_to_garbage_collect: Vec, } impl Data { diff --git a/backend/canisters/user/impl/src/lifecycle/post_upgrade.rs b/backend/canisters/user/impl/src/lifecycle/post_upgrade.rs index f30038369d..b99e91d85f 100644 --- a/backend/canisters/user/impl/src/lifecycle/post_upgrade.rs +++ b/backend/canisters/user/impl/src/lifecycle/post_upgrade.rs @@ -1,12 +1,13 @@ use crate::lifecycle::{init_env, init_state}; use crate::memory::{get_stable_memory_map_memory, get_upgrades_memory}; use crate::{mutate_state, Data}; +use candid::Principal; use canister_logger::LogEntry; use canister_tracing_macros::trace; use ic_cdk::post_upgrade; use stable_memory::get_reader; use tracing::info; -use types::CanisterId; +use types::{CanisterId, Chat}; use user_canister::post_upgrade::Args; #[post_upgrade] @@ -38,6 +39,7 @@ fn post_upgrade(args: Args) { mutate_state(|state| { let now = state.env.now(); for chat in state.data.direct_chats.iter_mut() { + chat.events.set_chat(Chat::Direct(Principal::from(chat.them).into())); chat.events.remove_spurious_video_call_in_progress(now); let count_removed = chat.events.prune_updated_events(now); diff --git a/backend/canisters/user/impl/src/updates/delete_direct_chat.rs b/backend/canisters/user/impl/src/updates/delete_direct_chat.rs index ed18fae7da..28095fc371 100644 --- a/backend/canisters/user/impl/src/updates/delete_direct_chat.rs +++ b/backend/canisters/user/impl/src/updates/delete_direct_chat.rs @@ -2,7 +2,7 @@ use crate::guards::caller_is_owner; use crate::{mutate_state, run_regular_jobs, RuntimeState}; use canister_api_macros::update; use canister_tracing_macros::trace; -use chat_events::{DirectChatKeyPrefix, DirectChatThreadKeyPrefix, KeyPrefix}; +use stable_memory_map::{ChatEventKeyPrefix, KeyPrefix}; use user_canister::delete_direct_chat::{Response::*, *}; #[update(guard = "caller_is_owner", msgpack = true)] @@ -19,16 +19,22 @@ fn delete_direct_chat_impl(args: Args, state: &mut RuntimeState) -> Response { if args.block_user { state.data.block_user(args.user_id, now); } + state .data .stable_memory_keys_to_garbage_collect - .push(KeyPrefix::DirectChat(DirectChatKeyPrefix::new(args.user_id)).to_vec()); + .push(KeyPrefix::from(ChatEventKeyPrefix::new_from_direct_chat(args.user_id, None))); + for message_index in chat.events.thread_keys() { state .data .stable_memory_keys_to_garbage_collect - .push(KeyPrefix::DirectChatThread(DirectChatThreadKeyPrefix::new(args.user_id, message_index)).to_vec()); + .push(KeyPrefix::from(ChatEventKeyPrefix::new_from_direct_chat( + args.user_id, + Some(message_index), + ))); } + crate::jobs::garbage_collect_stable_memory::start_job_if_required(state); Success } else { diff --git a/backend/libraries/chat_events/src/lib.rs b/backend/libraries/chat_events/src/lib.rs index a4cf2b48ca..0710346988 100644 --- a/backend/libraries/chat_events/src/lib.rs +++ b/backend/libraries/chat_events/src/lib.rs @@ -17,4 +17,3 @@ pub use crate::chat_events_list::*; pub use crate::events_map::*; pub use crate::message_content_internal::*; pub use crate::metrics::*; -pub use crate::stable_memory::key::*; diff --git a/backend/libraries/chat_events/src/stable_memory/key.rs b/backend/libraries/chat_events/src/stable_memory/key.rs deleted file mode 100644 index 30727394d1..0000000000 --- a/backend/libraries/chat_events/src/stable_memory/key.rs +++ /dev/null @@ -1,340 +0,0 @@ -use candid::Principal; -use ic_stable_structures::storable::Bound; -use ic_stable_structures::Storable; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use stable_memory_map::KeyType; -use std::borrow::Cow; -use types::{CanisterId, ChannelId, Chat, EventIndex, MessageIndex, UserId}; - -#[derive(Clone, Eq, PartialEq, Ord, PartialOrd)] -pub struct Key { - prefix: KeyPrefix, - event_index: EventIndex, -} - -#[derive(Clone, Eq, PartialEq, Ord, PartialOrd)] -pub enum KeyPrefix { - DirectChat(DirectChatKeyPrefix), - GroupChat(GroupChatKeyPrefix), - Channel(ChannelKeyPrefix), - DirectChatThread(DirectChatThreadKeyPrefix), - GroupChatThread(GroupChatThreadKeyPrefix), - ChannelThread(ChannelThreadKeyPrefix), -} - -impl Key { - pub fn new(prefix: KeyPrefix, event_index: EventIndex) -> Key { - Key { prefix, event_index } - } - - pub fn prefix(&self) -> &KeyPrefix { - &self.prefix - } - - pub fn event_index(&self) -> EventIndex { - self.event_index - } - - pub fn thread_root_message_index(&self) -> Option { - self.prefix.thread_root_message_index() - } - - pub fn matches_chat(&self, chat: Chat) -> bool { - self.prefix.matches_chat(chat) - } - - pub fn key_type(&self) -> KeyType { - self.prefix.key_type() - } -} - -impl KeyPrefix { - pub fn new(chat: Chat, thread_root_message_index: Option) -> KeyPrefix { - match (chat, thread_root_message_index) { - (Chat::Direct(c), None) => KeyPrefix::DirectChat(DirectChatKeyPrefix::new(Principal::from(c).into())), - (Chat::Direct(c), Some(m)) => { - KeyPrefix::DirectChatThread(DirectChatThreadKeyPrefix::new(Principal::from(c).into(), m)) - } - (Chat::Group(_), None) => KeyPrefix::GroupChat(GroupChatKeyPrefix::default()), - (Chat::Group(_), Some(m)) => KeyPrefix::GroupChatThread(GroupChatThreadKeyPrefix::new(m)), - (Chat::Channel(_, c), None) => KeyPrefix::Channel(ChannelKeyPrefix::new(c)), - (Chat::Channel(_, c), Some(m)) => KeyPrefix::ChannelThread(ChannelThreadKeyPrefix::new(c, m)), - } - } - - pub fn matches_chat(&self, chat: Chat) -> bool { - match self { - KeyPrefix::DirectChat(k) => matches!(chat, Chat::Direct(c) if CanisterId::from(c) == k.user_id.0), - KeyPrefix::GroupChat(_) => matches!(chat, Chat::Group(_)), - KeyPrefix::Channel(k) => matches!(chat, Chat::Channel(_, c) if c == k.channel_id.into()), - KeyPrefix::DirectChatThread(k) => matches!(chat, Chat::Direct(c) if CanisterId::from(c) == k.user_id.0), - KeyPrefix::GroupChatThread(_) => matches!(chat, Chat::Group(_)), - KeyPrefix::ChannelThread(k) => matches!(chat, Chat::Channel(_, c) if c == k.channel_id.into()), - } - } - - pub fn thread_root_message_index(&self) -> Option { - match self { - KeyPrefix::DirectChat(_) | KeyPrefix::GroupChat(_) | KeyPrefix::Channel(_) => None, - KeyPrefix::DirectChatThread(k) => Some(k.thread_root_message_index.into()), - KeyPrefix::GroupChatThread(k) => Some(k.thread_root_message_index.into()), - KeyPrefix::ChannelThread(k) => Some(k.thread_root_message_index.into()), - } - } - - pub fn key_type(&self) -> KeyType { - match self { - KeyPrefix::DirectChat(_) => KeyType::DirectChatEvent, - KeyPrefix::GroupChat(_) => KeyType::GroupChatEvent, - KeyPrefix::Channel(_) => KeyType::ChannelEvent, - KeyPrefix::DirectChatThread(_) => KeyType::DirectChatThreadEvent, - KeyPrefix::GroupChatThread(_) => KeyType::GroupChatThreadEvent, - KeyPrefix::ChannelThread(_) => KeyType::ChannelThreadEvent, - } - } -} - -impl Key { - pub fn to_vec(&self) -> Vec { - let mut bytes = self.prefix.to_vec(); - bytes.extend_from_slice(&u32::from(self.event_index).to_be_bytes()); - bytes - } -} - -impl TryFrom<&[u8]> for Key { - type Error = (); - - fn try_from(bytes: &[u8]) -> Result { - let len = bytes.len(); - let prefix = KeyPrefix::try_from(&bytes[..len - 4])?; - let event_index = u32::from_be_bytes(bytes[(len - 4)..].try_into().unwrap()).into(); - Ok(Key { prefix, event_index }) - } -} - -impl KeyPrefix { - pub fn to_vec(&self) -> Vec { - let mut bytes = Vec::new(); - bytes.push(self.key_type() as u8); - bytes.extend_from_slice( - match self { - KeyPrefix::DirectChat(k) => k.to_bytes(), - KeyPrefix::GroupChat(k) => k.to_bytes(), - KeyPrefix::Channel(k) => k.to_bytes(), - KeyPrefix::DirectChatThread(k) => k.to_bytes(), - KeyPrefix::GroupChatThread(k) => k.to_bytes(), - KeyPrefix::ChannelThread(k) => k.to_bytes(), - } - .as_ref(), - ); - bytes - } -} - -impl TryFrom<&[u8]> for KeyPrefix { - type Error = (); - - fn try_from(bytes: &[u8]) -> Result { - let key_type = KeyType::from(bytes[0]); - let bytes = Cow::Borrowed(&bytes[1..]); - - match key_type { - KeyType::DirectChatEvent => Ok(KeyPrefix::DirectChat(DirectChatKeyPrefix::from_bytes(bytes))), - KeyType::GroupChatEvent => Ok(KeyPrefix::GroupChat(GroupChatKeyPrefix::from_bytes(bytes))), - KeyType::ChannelEvent => Ok(KeyPrefix::Channel(ChannelKeyPrefix::from_bytes(bytes))), - KeyType::DirectChatThreadEvent => Ok(KeyPrefix::DirectChatThread(DirectChatThreadKeyPrefix::from_bytes(bytes))), - KeyType::GroupChatThreadEvent => Ok(KeyPrefix::GroupChatThread(GroupChatThreadKeyPrefix::from_bytes(bytes))), - KeyType::ChannelThreadEvent => Ok(KeyPrefix::ChannelThread(ChannelThreadKeyPrefix::from_bytes(bytes))), - _ => Err(()), - } - } -} - -#[derive(Clone, Eq, PartialEq, Ord, PartialOrd)] -pub struct DirectChatKeyPrefix { - user_id: CanisterIdWithSize, -} - -impl DirectChatKeyPrefix { - pub fn new(user_id: UserId) -> Self { - Self { - user_id: CanisterIdWithSize(user_id.into()), - } - } -} - -#[derive(Clone, Default, Eq, PartialEq, Ord, PartialOrd)] -pub struct GroupChatKeyPrefix {} - -#[derive(Clone, Eq, PartialEq, Ord, PartialOrd)] -pub struct ChannelKeyPrefix { - channel_id: u32, -} - -impl ChannelKeyPrefix { - pub fn new(channel_id: ChannelId) -> Self { - ChannelKeyPrefix { - channel_id: channel_id.as_u32(), - } - } -} - -#[derive(Clone, Eq, PartialEq, Ord, PartialOrd)] -pub struct DirectChatThreadKeyPrefix { - user_id: CanisterIdWithSize, - thread_root_message_index: u32, -} - -impl DirectChatThreadKeyPrefix { - pub fn new(user_id: UserId, thread_root_message_index: MessageIndex) -> Self { - Self { - user_id: CanisterIdWithSize(user_id.into()), - thread_root_message_index: thread_root_message_index.into(), - } - } -} - -#[derive(Clone, Eq, PartialEq, Ord, PartialOrd)] -pub struct GroupChatThreadKeyPrefix { - thread_root_message_index: u32, -} - -impl GroupChatThreadKeyPrefix { - pub fn new(thread_root_message_index: MessageIndex) -> Self { - Self { - thread_root_message_index: thread_root_message_index.into(), - } - } -} - -#[derive(Clone, Eq, PartialEq, Ord, PartialOrd)] -pub struct ChannelThreadKeyPrefix { - channel_id: u32, - thread_root_message_index: u32, -} - -impl ChannelThreadKeyPrefix { - pub fn new(channel_id: ChannelId, thread_root_message_index: MessageIndex) -> Self { - ChannelThreadKeyPrefix { - channel_id: channel_id.as_u32(), - thread_root_message_index: thread_root_message_index.into(), - } - } -} - -fn read_value(bytes: &mut &[u8]) -> T { - let size = T::size(bytes[0]); - let value = T::from_bytes(Cow::Borrowed(&bytes[..size])); - *bytes = &bytes[size..]; - value -} - -trait SizeFromReader { - fn size(next_byte: u8) -> usize; -} - -impl SizeFromReader for CanisterIdWithSize { - fn size(next_byte: u8) -> usize { - next_byte as usize + 1 - } -} - -#[derive(Clone, Eq, PartialEq, Ord, PartialOrd)] -struct CanisterIdWithSize(CanisterId); - -impl Storable for CanisterIdWithSize { - fn to_bytes(&self) -> Cow<[u8]> { - let canister_id_len = self.0.as_ref().len(); - let mut bytes = Vec::with_capacity(canister_id_len + 1); - bytes.push(canister_id_len as u8); - bytes.extend_from_slice(self.0.as_ref()); - Cow::Owned(bytes) - } - - fn from_bytes(bytes: Cow<[u8]>) -> Self { - CanisterIdWithSize(CanisterId::from_slice(&bytes[1..])) - } - - const BOUND: Bound = Bound::Bounded { - max_size: CanisterId::BOUND.max_size() + 1, - is_fixed_size: false, - }; -} -macro_rules! storable_as_tuple { - ($ty:ident) => { - impl Storable for $ty { - fn to_bytes(&self) -> Cow<[u8]> { - Cow::Owned(Vec::new()) - } - - fn from_bytes(_bytes: Cow<[u8]>) -> Self { - Self {} - } - - const BOUND: Bound = Bound::Bounded { is_fixed_size: true, max_size: 0 }; - } - }; - ($ty:ident, $($field:ident),+) => { - impl Storable for $ty { - fn to_bytes(&self) -> Cow<[u8]> { - let mut bytes = Vec::new(); - $( - bytes.extend_from_slice(self.$field.to_bytes().as_ref()); - )* - Cow::Owned(bytes) - } - - fn from_bytes(bytes: Cow<[u8]>) -> Self { - let mut slice = bytes.as_ref(); - - Self { - $( - $field: read_value(&mut slice), - )* - } - } - - const BOUND: Bound = Bound::Unbounded; - } - }; -} - -storable_as_tuple!(DirectChatKeyPrefix, user_id); -storable_as_tuple!(GroupChatKeyPrefix); -storable_as_tuple!(ChannelKeyPrefix, channel_id); -storable_as_tuple!(DirectChatThreadKeyPrefix, user_id, thread_root_message_index); -storable_as_tuple!(GroupChatThreadKeyPrefix, thread_root_message_index); -storable_as_tuple!(ChannelThreadKeyPrefix, channel_id, thread_root_message_index); - -macro_rules! size_from_reader_fixed { - ($ty:ident) => { - impl SizeFromReader for $ty { - fn size(_: u8) -> usize { - size_of::<$ty>() - } - } - }; -} - -size_from_reader_fixed!(u32); - -impl Serialize for KeyPrefix { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let bytes = self.to_vec(); - serializer.serialize_bytes(&bytes) - } -} - -impl<'de> Deserialize<'de> for KeyPrefix { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let bytes: Vec = Vec::deserialize(deserializer)?; - Ok(KeyPrefix::try_from(bytes.as_slice()).unwrap()) - } -} diff --git a/backend/libraries/chat_events/src/stable_memory/mod.rs b/backend/libraries/chat_events/src/stable_memory/mod.rs index 6aa1b846fb..c52994693e 100644 --- a/backend/libraries/chat_events/src/stable_memory/mod.rs +++ b/backend/libraries/chat_events/src/stable_memory/mod.rs @@ -1,8 +1,7 @@ -use crate::stable_memory::key::{Key, KeyPrefix}; use crate::{ChatEventInternal, EventsMap}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; -use stable_memory_map::{with_map, with_map_mut}; +use stable_memory_map::{with_map, with_map_mut, ChatEventKey, ChatEventKeyPrefix, Key}; use std::cmp::min; use std::collections::VecDeque; use std::ops::RangeBounds; @@ -10,26 +9,24 @@ use types::{ Chat, EventContext, EventIndex, EventWrapperInternal, MessageIndex, TimestampMillis, MAX_EVENT_INDEX, MIN_EVENT_INDEX, }; -pub mod key; - #[cfg(test)] mod tests; // Used to efficiently read all events from stable memory when migrating a group into a community pub fn read_events_as_bytes(chat: Chat, after: Option, max_bytes: usize) -> Vec<(EventContext, ByteBuf)> { let key = match after { - None => Key::new(KeyPrefix::new(chat, None), EventIndex::default()), + None => ChatEventKeyPrefix::new_from_chat(chat, None).create_key(EventIndex::default()), Some(EventContext { thread_root_message_index, event_index, - }) => Key::new(KeyPrefix::new(chat, thread_root_message_index), event_index.incr()), + }) => ChatEventKeyPrefix::new_from_chat(chat, thread_root_message_index).create_key(event_index.incr()), }; with_map(|m| { let mut total_bytes = 0; - m.range(key.to_vec()..) - .map_while(|(k, v)| Key::try_from(k.as_slice()).ok().map(|k| (k, v))) + m.range(Key::from(key)..) + .map_while(|(k, v)| ChatEventKey::try_from(k).ok().map(|k| (k, v))) .take_while(|(k, v)| { - if !k.matches_chat(chat) { + if !k.matches_chat(&chat) { return false; } total_bytes += v.len(); @@ -44,32 +41,33 @@ pub fn read_events_as_bytes(chat: Chat, after: Option, max_bytes: pub fn write_events_as_bytes(chat: Chat, events: Vec<(EventContext, ByteBuf)>) { with_map_mut(|m| { for (context, bytes) in events { - let prefix = KeyPrefix::new(chat, context.thread_root_message_index); - let key = Key::new(prefix, context.event_index).to_vec(); + let prefix = ChatEventKeyPrefix::new_from_chat(chat, context.thread_root_message_index); + let key = prefix.create_key(context.event_index); let value = bytes.into_vec(); // Check the event is valid. We could remove this once we're more confident let _ = bytes_to_event(&value); - m.insert(key, value); + m.insert(key.into(), value); } }); } #[derive(Serialize, Deserialize)] pub struct ChatEventsStableStorage { - prefix: KeyPrefix, + #[serde(skip_deserializing, default = "default_key_prefix")] + prefix: ChatEventKeyPrefix, +} + +fn default_key_prefix() -> ChatEventKeyPrefix { + ChatEventKeyPrefix::new_from_group_chat(None) } impl ChatEventsStableStorage { pub fn new(chat: Chat, thread_root_message_index: Option) -> Self { ChatEventsStableStorage { - prefix: KeyPrefix::new(chat, thread_root_message_index), + prefix: ChatEventKeyPrefix::new_from_chat(chat, thread_root_message_index), } } - fn key(&self, event_index: EventIndex) -> Key { - Key::new(self.prefix.clone(), event_index) - } - fn iter_as_bytes(&self) -> Iter { Iter::new(self.prefix.clone(), MIN_EVENT_INDEX, MAX_EVENT_INDEX) } @@ -92,8 +90,8 @@ impl ChatEventsStableStorage { } fn get_internal(&self, event_index: EventIndex) -> Option> { - let key = self.key(event_index); - with_map(|m| m.get(&key.to_vec())) + let key = self.prefix.create_key(event_index); + with_map(|m| m.get(&Key::from(key))) } } @@ -107,13 +105,11 @@ impl EventsMap for ChatEventsStableStorage { } fn insert(&mut self, event: EventWrapperInternal) { - let key = self.key(event.index); - with_map_mut(|m| m.insert(key.to_vec(), event_to_bytes(event))); + with_map_mut(|m| m.insert(Key::from(self.prefix.create_key(event.index)), event_to_bytes(event))); } fn remove(&mut self, event_index: EventIndex) -> Option> { - let key = self.key(event_index); - with_map_mut(|m| m.remove(&key.to_vec())).map(|v| bytes_to_event(&v)) + with_map_mut(|m| m.remove(&Key::from(self.prefix.create_key(event_index)))).map(|v| bytes_to_event(&v)) } fn range>( @@ -157,7 +153,7 @@ const DEFAULT_BUFFER_SIZE: usize = 20; const MAX_BUFFER_SIZE: usize = 1000; struct Iter { - prefix: KeyPrefix, + prefix: ChatEventKeyPrefix, next: EventIndex, next_back: EventIndex, is_forward_buffer: bool, @@ -167,7 +163,7 @@ struct Iter { } impl Iter { - fn new(prefix: KeyPrefix, start: EventIndex, end: EventIndex) -> Self { + fn new(prefix: ChatEventKeyPrefix, start: EventIndex, end: EventIndex) -> Self { Iter { prefix, next: start, @@ -179,7 +175,7 @@ impl Iter { } } - fn empty(prefix: KeyPrefix) -> Iter { + fn empty(prefix: ChatEventKeyPrefix) -> Iter { Iter { prefix, next: EventIndex::default(), @@ -191,12 +187,12 @@ impl Iter { } } - fn next_key(&self) -> Key { - Key::new(self.prefix.clone(), self.next) + fn next_key(&self) -> ChatEventKey { + self.prefix.create_key(self.next) } - fn next_back_key(&self) -> Key { - Key::new(self.prefix.clone(), self.next_back) + fn next_back_key(&self) -> ChatEventKey { + self.prefix.create_key(self.next_back) } fn check_buffer_direction(&mut self, forward: bool) { @@ -236,8 +232,8 @@ impl Iterator for Iter { self.check_buffer_direction(true); if self.buffer.is_empty() { self.buffer = with_map(|m| { - m.range(self.next_key().to_vec()..=self.next_back_key().to_vec()) - .map_while(|(k, v)| Key::try_from(k.as_slice()).ok().map(|k| (k.event_index(), v))) + m.range(Key::from(self.next_key())..=Key::from(self.next_back_key())) + .map_while(|(k, v)| ChatEventKey::try_from(k).ok().map(|k| (k.event_index(), v))) .take(self.next_buffer_size) .collect() }); @@ -261,9 +257,9 @@ impl DoubleEndedIterator for Iter { self.check_buffer_direction(false); if self.buffer.is_empty() { self.buffer = with_map(|m| { - m.range(self.next_key().to_vec()..=self.next_back_key().to_vec()) + m.range(Key::from(self.next_key())..=Key::from(self.next_back_key())) .rev() - .map_while(|(k, v)| Key::try_from(k.as_slice()).ok().map(|k| (k.event_index(), v))) + .map_while(|(k, v)| ChatEventKey::try_from(k).ok().map(|k| (k.event_index(), v))) .take(self.next_buffer_size) .collect() }); diff --git a/backend/libraries/group_chat_core/Cargo.toml b/backend/libraries/group_chat_core/Cargo.toml index ec847213c2..d527d1ce11 100644 --- a/backend/libraries/group_chat_core/Cargo.toml +++ b/backend/libraries/group_chat_core/Cargo.toml @@ -28,5 +28,4 @@ utils = { path = "../utils" } ic-stable-structures = { workspace = true } msgpack = { path = "../msgpack" } proptest = { workspace = true } -rand = { workspace = true } test-strategy = { workspace = true } diff --git a/backend/libraries/group_chat_core/src/members.rs b/backend/libraries/group_chat_core/src/members.rs index b5b83ee859..d79c8cfca5 100644 --- a/backend/libraries/group_chat_core/src/members.rs +++ b/backend/libraries/group_chat_core/src/members.rs @@ -23,8 +23,6 @@ use utils::timestamped_set::TimestampedSet; mod proptests; mod stable_memory; -pub use stable_memory::KeyPrefix as MembersKeyPrefix; - const MAX_MEMBERS_PER_GROUP: u32 = 100_000; #[derive(Serialize, Deserialize)] diff --git a/backend/libraries/group_chat_core/src/members/stable_memory.rs b/backend/libraries/group_chat_core/src/members/stable_memory.rs index 43a660d584..edef80428f 100644 --- a/backend/libraries/group_chat_core/src/members/stable_memory.rs +++ b/backend/libraries/group_chat_core/src/members/stable_memory.rs @@ -1,50 +1,50 @@ use crate::members_map::MembersMap; use crate::GroupMemberInternal; use candid::{Deserialize, Principal}; -use serde::de::{Error, Visitor}; -use serde::{Deserializer, Serialize, Serializer}; +use serde::Serialize; use serde_bytes::ByteBuf; -use stable_memory_map::{with_map, with_map_mut, KeyType}; -use std::fmt::Formatter; +use stable_memory_map::{with_map, with_map_mut, Key, MemberKey, MemberKeyPrefix}; use types::{MultiUserChat, UserId}; #[derive(Serialize, Deserialize)] pub struct MembersStableStorage { - prefix: KeyPrefix, + prefix: MemberKeyPrefix, } impl MembersStableStorage { // TODO delete this after next upgrade pub fn new_empty() -> Self { MembersStableStorage { - prefix: KeyPrefix::GroupChat, + prefix: MemberKeyPrefix::new_from_chat(MultiUserChat::Group(Principal::anonymous().into())), } } #[allow(dead_code)] pub fn new(chat: MultiUserChat, member: GroupMemberInternal) -> Self { - let mut map = MembersStableStorage { prefix: chat.into() }; + let mut map = MembersStableStorage { + prefix: MemberKeyPrefix::new_from_chat(chat), + }; map.insert(member); map } pub fn set_chat(&mut self, chat: MultiUserChat) { - self.prefix = chat.into(); + self.prefix = MemberKeyPrefix::new_from_chat(chat); } // Used to efficiently read all members from stable memory when migrating a group into a community pub fn read_members_as_bytes(&self, after: Option, max_bytes: usize) -> Vec { let start_key = match after { - None => self.key(Principal::from_slice(&[]).into()), - Some(user_id) => self.key(user_id), - } - .to_vec(); + None => self.prefix.create_key(Principal::from_slice(&[]).into()), + Some(user_id) => self.prefix.create_key(user_id), + }; with_map(|m| { let mut total_bytes = 0; - m.range(start_key.clone()..) + m.range(Key::from(start_key.clone())..) + .map_while(|(k, v)| MemberKey::try_from(k).ok().map(|k| (k, v))) .skip_while(|(k, _)| *k == start_key) - .take_while(|(k, _)| KeyPrefix::try_from(k.as_slice()).is_ok_and(|p| p == self.prefix)) + .take_while(|(k, _)| k.matches_prefix(&self.prefix)) .map(|(_, v)| ByteBuf::from(v)) .take_while(|v| { total_bytes += v.len(); @@ -53,30 +53,30 @@ impl MembersStableStorage { .collect() }) } - - fn key(&self, user_id: UserId) -> Key { - Key::new(self.prefix, user_id) - } } impl MembersMap for MembersStableStorage { fn get(&self, user_id: &UserId) -> Option { - with_map(|m| m.get(&self.key(*user_id).to_vec()).map(|v| bytes_to_member(&v))) + with_map(|m| m.get(&self.prefix.create_key(*user_id).into()).map(|v| bytes_to_member(&v))) } fn insert(&mut self, member: GroupMemberInternal) { - with_map_mut(|m| m.insert(self.key(member.user_id).to_vec(), member_to_bytes(&member))); + with_map_mut(|m| m.insert(self.prefix.create_key(member.user_id).into(), member_to_bytes(&member))); } fn remove(&mut self, user_id: &UserId) -> Option { - with_map_mut(|m| m.remove(&self.key(*user_id).to_vec()).map(|v| bytes_to_member(&v))) + with_map_mut(|m| { + m.remove(&self.prefix.create_key(*user_id).into()) + .map(|v| bytes_to_member(&v)) + }) } #[cfg(test)] fn all_members(&self) -> Vec { with_map(|m| { - m.range(self.key(Principal::from_slice(&[]).into()).to_vec()..) - .take_while(|(k, _)| Key::try_from(k.as_slice()).ok().filter(|k| k.prefix == self.prefix).is_some()) + m.range(Key::from(self.prefix.create_key(Principal::from_slice(&[]).into()))..) + .map_while(|(k, v)| MemberKey::try_from(k).ok().map(|k| (k, v))) + .take_while(|(k, _)| k.matches_prefix(&self.prefix)) .map(|(_, v)| bytes_to_member(&v)) .collect() }) @@ -85,14 +85,14 @@ impl MembersMap for MembersStableStorage { // Used to write all members to stable memory when migrating a group into a community pub fn write_members_from_bytes(chat: MultiUserChat, members: Vec) -> Option { - let prefix = chat.into(); + let prefix = MemberKeyPrefix::new_from_chat(chat); let mut latest = None; with_map_mut(|m| { for byte_buf in members { let bytes = byte_buf.into_vec(); let member = bytes_to_member(&bytes); latest = Some(member.user_id); - m.insert(Key::new(prefix, member.user_id).to_vec(), bytes); + m.insert(prefix.create_key(member.user_id).into(), bytes); } }); latest @@ -105,178 +105,3 @@ fn member_to_bytes(member: &GroupMemberInternal) -> Vec { fn bytes_to_member(bytes: &[u8]) -> GroupMemberInternal { msgpack::deserialize_then_unwrap(bytes) } - -#[derive(Eq, PartialEq, Ord, PartialOrd, Debug)] -pub struct Key { - prefix: KeyPrefix, - user_id: UserId, -} - -impl Key { - fn new(prefix: KeyPrefix, user_id: UserId) -> Self { - Self { prefix, user_id } - } -} - -impl Key { - fn to_vec(&self) -> Vec { - let user_id_bytes = self.user_id.as_slice(); - let mut bytes = Vec::with_capacity(self.prefix.byte_len() + user_id_bytes.len()); - bytes.extend_from_slice(&self.prefix.to_vec()); - bytes.extend_from_slice(user_id_bytes); - bytes - } -} - -impl TryFrom<&[u8]> for Key { - type Error = (); - - fn try_from(value: &[u8]) -> Result { - let prefix = KeyPrefix::try_from(value)?; - let prefix_bytes_len = prefix.byte_len(); - let user_id = Principal::from_slice(&value[prefix_bytes_len..]).into(); - - Ok(Key::new(prefix, user_id)) - } -} - -#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Debug)] -pub enum KeyPrefix { - GroupChat, - Channel(u32), -} - -impl KeyPrefix { - pub fn to_vec(self) -> Vec { - match self { - KeyPrefix::GroupChat => vec![KeyType::ChatMember as u8, 1], - KeyPrefix::Channel(channel_id) => { - let mut vec = Vec::with_capacity(6); - vec.push(KeyType::ChatMember as u8); - vec.push(2); - vec.extend_from_slice(&channel_id.to_be_bytes()); - vec - } - } - } - - fn byte_len(&self) -> usize { - match self { - KeyPrefix::GroupChat => 2, - KeyPrefix::Channel(_) => 6, - } - } -} - -impl TryFrom<&[u8]> for KeyPrefix { - type Error = (); - - // The slice may extend beyond the bytes of the prefix - fn try_from(value: &[u8]) -> Result { - match value.split_first() { - Some((kt, bytes)) if *kt == KeyType::ChatMember as u8 => match bytes.split_first() { - Some((1, _)) => Ok(KeyPrefix::GroupChat), - Some((2, tail)) if tail.len() >= 4 => Ok(KeyPrefix::Channel(u32::from_be_bytes(tail[..4].try_into().unwrap()))), - _ => Err(()), - }, - _ => Err(()), - } - } -} - -impl From for KeyPrefix { - fn from(value: MultiUserChat) -> Self { - match value { - MultiUserChat::Group(_) => KeyPrefix::GroupChat, - MultiUserChat::Channel(_, c) => KeyPrefix::Channel(c.as_u32()), - } - } -} - -impl Serialize for KeyPrefix { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_bytes(&self.to_vec()) - } -} - -struct KeyPrefixVisitor; - -impl<'de> Visitor<'de> for KeyPrefixVisitor { - type Value = KeyPrefix; - - fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { - formatter.write_str("a byte array") - } - - fn visit_bytes(self, v: &[u8]) -> Result { - KeyPrefix::try_from(v).map_err(|_| E::custom("invalid key prefix")) - } -} - -impl<'de> Deserialize<'de> for KeyPrefix { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - deserializer.deserialize_bytes(KeyPrefixVisitor) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use rand::{thread_rng, Rng, RngCore}; - - #[test] - fn group_key_roundtrip() { - for _ in 0..100 { - let user_id_bytes: [u8; 10] = thread_rng().gen(); - let user_id = Principal::from_slice(&user_id_bytes).into(); - - let key_in = Key::new(KeyPrefix::GroupChat, user_id); - let bytes = key_in.to_vec(); - let key_out = Key::try_from(bytes.as_slice()).unwrap(); - - assert_eq!(key_in, key_out); - } - } - - #[test] - fn channel_key_roundtrip() { - for _ in 0..100 { - let channel_id: u32 = thread_rng().next_u32(); - let user_id_bytes: [u8; 10] = thread_rng().gen(); - let user_id = Principal::from_slice(&user_id_bytes).into(); - - let key_in = Key::new(KeyPrefix::Channel(channel_id), user_id); - let bytes = key_in.to_vec(); - let key_out = Key::try_from(bytes.as_slice()).unwrap(); - - assert_eq!(key_in, key_out); - } - } - - #[test] - fn group_key_prefix_serialization_roundtrip() { - let key_prefix_in = KeyPrefix::GroupChat; - let bytes = msgpack::serialize_then_unwrap(key_prefix_in); - let key_prefix_out = msgpack::deserialize_then_unwrap(&bytes); - - assert_eq!(key_prefix_in, key_prefix_out); - } - - #[test] - fn channel_key_prefix_serialization_roundtrip() { - for _ in 0..100 { - let channel_id: u32 = thread_rng().next_u32(); - let key_prefix_in = KeyPrefix::Channel(channel_id); - let bytes = msgpack::serialize_then_unwrap(key_prefix_in); - let key_prefix_out = msgpack::deserialize_then_unwrap(&bytes); - - assert_eq!(key_prefix_in, key_prefix_out); - } - } -} diff --git a/backend/libraries/stable_memory_map/Cargo.toml b/backend/libraries/stable_memory_map/Cargo.toml index fd0b3d5455..a22b9e6e87 100644 --- a/backend/libraries/stable_memory_map/Cargo.toml +++ b/backend/libraries/stable_memory_map/Cargo.toml @@ -7,4 +7,12 @@ edition = "2021" [dependencies] ic-cdk = { workspace = true } +ic_principal = { workspace = true } ic-stable-structures = { workspace = true } +serde = { workspace = true } +serde_bytes = { workspace = true } +types = { path = "../types" } + +[dev-dependencies] +msgpack = { path = "../msgpack" } +rand = { workspace = true } diff --git a/backend/libraries/stable_memory_map/src/key.rs b/backend/libraries/stable_memory_map/src/key.rs new file mode 100644 index 0000000000..8b22f2face --- /dev/null +++ b/backend/libraries/stable_memory_map/src/key.rs @@ -0,0 +1,547 @@ +use ic_principal::Principal; +use ic_stable_structures::storable::Bound; +use ic_stable_structures::Storable; +use serde::{Deserialize, Serialize}; +use std::borrow::Cow; +use types::{ChannelId, Chat, EventIndex, MessageIndex, MultiUserChat, UserId}; + +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] +#[serde(transparent)] +pub struct Key(#[serde(with = "serde_bytes")] Vec); + +impl Key { + pub fn starts_with(&self, prefix: &[u8]) -> bool { + self.0.starts_with(prefix) + } + + pub fn as_slice(&self) -> &[u8] { + &self.0 + } +} + +impl Storable for Key { + fn to_bytes(&self) -> Cow<[u8]> { + Cow::Borrowed(&self.0) + } + + fn from_bytes(bytes: Cow<[u8]>) -> Self { + Key(bytes.to_vec()) + } + + const BOUND: Bound = Bound::Unbounded; +} + +impl From for Key { + fn from(value: KeyPrefix) -> Self { + Key(value.0) + } +} + +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] +#[serde(transparent)] +pub struct KeyPrefix(#[serde(with = "serde_bytes")] Vec); + +impl KeyPrefix { + pub fn as_slice(&self) -> &[u8] { + &self.0 + } +} + +// ChatEventKeyPrefix + EventIndex +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] +#[serde(into = "Key", try_from = "Key")] +pub struct ChatEventKey(Vec); + +// MemberKeyPrefix + UserId +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] +#[serde(into = "Key", try_from = "Key")] +pub struct MemberKey(Vec); + +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] +#[serde(into = "KeyPrefix", try_from = "KeyPrefix")] +pub struct ChatEventKeyPrefix(Vec); + +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] +#[serde(into = "KeyPrefix", try_from = "KeyPrefix")] +pub struct MemberKeyPrefix(Vec); + +impl ChatEventKeyPrefix { + pub fn new_from_chat(chat: Chat, thread_root_message_index: Option) -> Self { + match chat { + Chat::Direct(user_id) => Self::new_from_direct_chat(Principal::from(user_id).into(), thread_root_message_index), + Chat::Group(_) => Self::new_from_group_chat(thread_root_message_index), + Chat::Channel(_, channel_id) => Self::new_from_channel(channel_id, thread_root_message_index), + } + } + + pub fn new_from_direct_chat(user_id: UserId, thread_root_message_index: Option) -> Self { + // We don't actually need the userId length marker but existing entries have it so we + // need to keep it to be backwards compatible. If we decide we want to remove it then we + // can switch to a new KeyType for direct chat events, but that is quite a lot of work + // and complexity since we'd still have to support the old version too (or migrate them). + + let user_id_bytes = user_id.as_slice(); + + match thread_root_message_index { + None => { + // KeyType::DirectChatThreadEvent 1 byte + // UserId length 1 byte + // UserId bytes UserId length bytes + let mut bytes = Vec::with_capacity(user_id_bytes.len() + 2); + bytes.push(KeyType::DirectChatEvent as u8); + bytes.push(user_id_bytes.len() as u8); + bytes.extend_from_slice(user_id_bytes); + ChatEventKeyPrefix(bytes) + } + Some(root_message_index) => { + // KeyType::DirectChatThreadEvent 1 byte + // UserId length 1 byte + // UserId bytes UserId length bytes + // Thread root message index 4 bytes + let mut bytes = Vec::with_capacity(user_id_bytes.len() + 6); + bytes.push(KeyType::DirectChatThreadEvent as u8); + bytes.push(user_id_bytes.len() as u8); + bytes.extend_from_slice(user_id_bytes); + bytes.extend_from_slice(&u32::from(root_message_index).to_be_bytes()); + ChatEventKeyPrefix(bytes) + } + } + } + + pub fn new_from_group_chat(thread_root_message_index: Option) -> Self { + match thread_root_message_index { + None => { + // KeyType::GroupChatEvent 1 byte + ChatEventKeyPrefix(vec![KeyType::GroupChatEvent as u8]) + } + Some(root_message_index) => { + // KeyType::GroupChatThreadEvent 1 byte + // Thread root message index 4 bytes + let mut bytes = Vec::with_capacity(5); + bytes.push(KeyType::GroupChatThreadEvent as u8); + bytes.extend_from_slice(&u32::from(root_message_index).to_be_bytes()); + ChatEventKeyPrefix(bytes) + } + } + } + + pub fn new_from_channel(channel_id: ChannelId, thread_root_message_index: Option) -> Self { + match thread_root_message_index { + None => { + // KeyType::ChannelEvent 1 byte + // ChannelId 4 bytes + let mut bytes = Vec::with_capacity(5); + bytes.push(KeyType::ChannelEvent as u8); + bytes.extend_from_slice(&channel_id.as_u32().to_be_bytes()); + ChatEventKeyPrefix(bytes) + } + Some(root_message_index) => { + // KeyType::ChannelThreadEvent 1 byte + // ChannelId 4 bytes + // Thread root message index 4 bytes + let mut bytes = Vec::with_capacity(9); + bytes.push(KeyType::ChannelThreadEvent as u8); + bytes.extend_from_slice(&channel_id.as_u32().to_be_bytes()); + bytes.extend_from_slice(&u32::from(root_message_index).to_be_bytes()); + ChatEventKeyPrefix(bytes) + } + } + } + + pub fn create_key(&self, event_index: EventIndex) -> ChatEventKey { + let mut bytes = Vec::with_capacity(self.0.len() + 4); + bytes.extend_from_slice(self.0.as_slice()); + bytes.extend_from_slice(&u32::from(event_index).to_be_bytes()); + ChatEventKey(bytes) + } +} + +impl TryFrom for ChatEventKeyPrefix { + type Error = String; + + fn try_from(value: KeyPrefix) -> Result { + if extract_key_type(&value.0).is_some_and(|kt| kt.is_chat_event_key()) { + Ok(ChatEventKeyPrefix(value.0)) + } else { + Err(format!("Key type mismatch: {:?}", value.0.first())) + } + } +} + +impl TryFrom for ChatEventKey { + type Error = String; + + fn try_from(value: Key) -> Result { + if extract_key_type(&value.0).is_some_and(|kt| kt.is_chat_event_key()) { + Ok(ChatEventKey(value.0)) + } else { + Err(format!("Key type mismatch: {:?}", value.0.first())) + } + } +} + +impl From for KeyPrefix { + fn from(value: ChatEventKeyPrefix) -> Self { + KeyPrefix(value.0) + } +} + +impl From for Key { + fn from(value: ChatEventKey) -> Self { + Key(value.0) + } +} + +impl ChatEventKey { + pub fn matches_prefix(&self, prefix: &ChatEventKeyPrefix) -> bool { + self.0.starts_with(&prefix.0) + } + + pub fn matches_chat(&self, chat: &Chat) -> bool { + match (chat, self.key_type()) { + (Chat::Direct(id), KeyType::DirectChatEvent | KeyType::DirectChatThreadEvent) => { + let user_id_len = self.0[1] as usize; + let user_id_start = 2; + let user_id_end = user_id_start + user_id_len; + let user_id = Principal::from_slice(&self.0[user_id_start..user_id_end]).into(); + *id == user_id + } + (Chat::Group(_), KeyType::GroupChatEvent | KeyType::GroupChatThreadEvent) => true, + (Chat::Channel(_, id), KeyType::ChannelEvent | KeyType::ChannelThreadEvent) => { + let channel_id = u32::from_be_bytes(self.0[1..5].try_into().unwrap()).into(); + *id == channel_id + } + _ => false, + } + } + + pub fn thread_root_message_index(&self) -> Option { + if matches!( + self.key_type(), + KeyType::DirectChatThreadEvent | KeyType::GroupChatThreadEvent | KeyType::ChannelThreadEvent + ) { + let start = self.0.len() - 8; + let end = start + 4; + Some(u32::from_be_bytes(self.0[start..end].try_into().unwrap()).into()) + } else { + None + } + } + + pub fn event_index(&self) -> EventIndex { + let start = self.0.len() - 4; + u32::from_be_bytes(self.0[start..].try_into().unwrap()).into() + } + + fn key_type(&self) -> KeyType { + extract_key_type(&self.0).unwrap() + } +} + +impl MemberKeyPrefix { + pub fn new_from_chat(chat: MultiUserChat) -> Self { + match chat { + MultiUserChat::Group(_) => Self::new_from_group(), + MultiUserChat::Channel(_, channel_id) => Self::new_from_channel(channel_id), + } + } + + pub fn new_from_group() -> Self { + // KeyType::GroupMember 1 byte + MemberKeyPrefix(vec![KeyType::GroupMember as u8]) + } + + pub fn new_from_channel(channel_id: ChannelId) -> Self { + // KeyType::ChannelMember 1 byte + // ChannelId 4 bytes + let mut bytes = Vec::with_capacity(5); + bytes.push(KeyType::ChannelMember as u8); + bytes.extend_from_slice(&channel_id.as_u32().to_be_bytes()); + MemberKeyPrefix(bytes) + } + + pub fn new_from_community() -> Self { + // KeyType::CommunityMember 1 byte + MemberKeyPrefix(vec![KeyType::CommunityMember as u8]) + } + + pub fn create_key(&self, user_id: UserId) -> MemberKey { + let user_id_bytes = user_id.as_slice(); + let mut bytes = Vec::with_capacity(self.0.len() + user_id_bytes.len()); + bytes.extend_from_slice(self.0.as_slice()); + bytes.extend_from_slice(user_id_bytes); + MemberKey(bytes) + } +} + +impl From for KeyPrefix { + fn from(value: MemberKeyPrefix) -> Self { + KeyPrefix(value.0) + } +} + +impl From for Key { + fn from(value: MemberKey) -> Self { + Key(value.0) + } +} + +impl TryFrom for MemberKeyPrefix { + type Error = String; + + fn try_from(value: KeyPrefix) -> Result { + if extract_key_type(&value.0).is_some_and(|kt| kt.is_member_key()) { + Ok(MemberKeyPrefix(value.0)) + } else { + Err(format!("Key type mismatch: {:?}", value.0.first())) + } + } +} + +impl TryFrom for MemberKey { + type Error = String; + + fn try_from(value: Key) -> Result { + if extract_key_type(&value.0).is_some_and(|kt| kt.is_member_key()) { + Ok(MemberKey(value.0)) + } else { + Err(format!("Key type mismatch: {:?}", value.0.first())) + } + } +} + +impl MemberKey { + pub fn matches_prefix(&self, prefix: &MemberKeyPrefix) -> bool { + self.0.starts_with(&prefix.0) + } + + pub fn user_id(&self) -> UserId { + let prefix_len = match self.key_type() { + KeyType::GroupMember | KeyType::CommunityMember => 1, + KeyType::ChannelMember => 5, + _ => unreachable!(), + }; + Principal::from_slice(&self.0[prefix_len..]).into() + } + + fn key_type(&self) -> KeyType { + KeyType::try_from(self.0[0]).unwrap() + } +} + +#[derive(Copy, Clone)] +#[repr(u8)] +pub enum KeyType { + DirectChatEvent = 1, + GroupChatEvent = 2, + ChannelEvent = 3, + DirectChatThreadEvent = 4, + GroupChatThreadEvent = 5, + ChannelThreadEvent = 6, + GroupMember = 7, + ChannelMember = 8, + CommunityMember = 9, +} + +fn extract_key_type(bytes: &[u8]) -> Option { + bytes.first().and_then(|b| KeyType::try_from(*b).ok()) +} + +impl KeyType { + pub fn is_chat_event_key(&self) -> bool { + matches!( + self, + KeyType::DirectChatEvent + | KeyType::GroupChatEvent + | KeyType::ChannelEvent + | KeyType::DirectChatThreadEvent + | KeyType::GroupChatThreadEvent + | KeyType::ChannelThreadEvent + ) + } + + pub fn is_member_key(&self) -> bool { + matches!(self, KeyType::GroupMember | KeyType::ChannelMember | KeyType::CommunityMember) + } +} + +impl TryFrom for KeyType { + type Error = (); + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(KeyType::DirectChatEvent), + 2 => Ok(KeyType::GroupChatEvent), + 3 => Ok(KeyType::ChannelEvent), + 4 => Ok(KeyType::DirectChatThreadEvent), + 5 => Ok(KeyType::GroupChatThreadEvent), + 6 => Ok(KeyType::ChannelThreadEvent), + 7 => Ok(KeyType::GroupMember), + 8 => Ok(KeyType::ChannelMember), + 9 => Ok(KeyType::CommunityMember), + _ => Err(()), + } + } +} +#[cfg(test)] +mod tests { + use super::*; + use crate::{ChatEventKey, ChatEventKeyPrefix}; + use rand::{thread_rng, Rng, RngCore}; + use types::{ChannelId, Chat, EventIndex, MessageIndex}; + + #[test] + fn direct_chat_event_key_e2e() { + for thread in [false, true] { + for _ in 0..100 { + let user_id_bytes: [u8; 10] = thread_rng().gen(); + let user_id = Principal::from_slice(&user_id_bytes); + let thread_root_message_index = thread.then(|| MessageIndex::from(thread_rng().next_u32())); + let prefix = ChatEventKeyPrefix::new_from_direct_chat(user_id.into(), thread_root_message_index); + let event_index = EventIndex::from(thread_rng().next_u32()); + let key = Key::from(prefix.create_key(event_index)); + let event_key = ChatEventKey::try_from(key.clone()).unwrap(); + + assert_eq!( + *event_key.0.first().unwrap(), + if thread { KeyType::DirectChatThreadEvent } else { KeyType::DirectChatEvent } as u8 + ); + assert_eq!(event_key.0.len(), if thread { 20 } else { 16 }); + assert!(event_key.matches_prefix(&prefix)); + assert!(event_key.matches_chat(&Chat::Direct(user_id.into()))); + assert_eq!(event_key.event_index(), event_index); + + let serialized = msgpack::serialize_then_unwrap(&event_key); + assert_eq!(serialized.len(), event_key.0.len() + 2); + let deserialized: ChatEventKey = msgpack::deserialize_then_unwrap(&serialized); + assert_eq!(deserialized, event_key); + assert_eq!(deserialized.0, key.0); + } + } + } + + #[test] + fn group_chat_event_key_e2e() { + for thread in [false, true] { + for _ in 0..100 { + let thread_root_message_index = thread.then(|| MessageIndex::from(thread_rng().next_u32())); + let prefix = ChatEventKeyPrefix::new_from_group_chat(thread_root_message_index); + let event_index = EventIndex::from(thread_rng().next_u32()); + let key = Key::from(prefix.create_key(event_index)); + let event_key = ChatEventKey::try_from(key.clone()).unwrap(); + + assert_eq!( + *event_key.0.first().unwrap(), + if thread { KeyType::GroupChatThreadEvent } else { KeyType::GroupChatEvent } as u8 + ); + assert_eq!(event_key.0.len(), if thread { 9 } else { 5 }); + assert!(event_key.matches_prefix(&prefix)); + assert!(event_key.matches_chat(&Chat::Group(Principal::anonymous().into()))); + assert_eq!(event_key.event_index(), event_index); + assert_eq!(event_key.thread_root_message_index(), thread_root_message_index); + + let serialized = msgpack::serialize_then_unwrap(&event_key); + assert_eq!(serialized.len(), event_key.0.len() + 2); + let deserialized: ChatEventKey = msgpack::deserialize_then_unwrap(&serialized); + assert_eq!(deserialized, event_key); + assert_eq!(deserialized.0, key.0); + } + } + } + + #[test] + fn channel_event_key_e2e() { + for thread in [false, true] { + for _ in 0..100 { + let channel_id = ChannelId::from(thread_rng().next_u32()); + let thread_root_message_index = thread.then(|| MessageIndex::from(thread_rng().next_u32())); + let prefix = ChatEventKeyPrefix::new_from_channel(channel_id, thread_root_message_index); + let event_index = EventIndex::from(thread_rng().next_u32()); + let key = Key::from(prefix.create_key(event_index)); + let event_key = ChatEventKey::try_from(key.clone()).unwrap(); + + assert_eq!( + *event_key.0.first().unwrap(), + if thread { KeyType::ChannelThreadEvent } else { KeyType::ChannelEvent } as u8 + ); + assert_eq!(event_key.0.len(), if thread { 13 } else { 9 }); + assert!(event_key.matches_prefix(&prefix)); + assert!(event_key.matches_chat(&Chat::Channel(Principal::anonymous().into(), channel_id))); + assert_eq!(event_key.event_index(), event_index); + + let serialized = msgpack::serialize_then_unwrap(&event_key); + assert_eq!(serialized.len(), event_key.0.len() + 2); + let deserialized: ChatEventKey = msgpack::deserialize_then_unwrap(&serialized); + assert_eq!(deserialized, event_key); + assert_eq!(deserialized.0, key.0); + } + } + } + + #[test] + fn group_chat_member_key_e2e() { + for _ in 0..100 { + let user_id_bytes: [u8; 10] = thread_rng().gen(); + let user_id = UserId::from(Principal::from_slice(&user_id_bytes)); + let prefix = MemberKeyPrefix::new_from_group(); + let key = Key::from(prefix.create_key(user_id)); + let member_key = MemberKey::try_from(key.clone()).unwrap(); + + assert_eq!(*member_key.0.first().unwrap(), KeyType::GroupMember as u8); + assert_eq!(member_key.0.len(), 11); + assert!(member_key.matches_prefix(&prefix)); + assert_eq!(member_key.user_id(), user_id); + + let serialized = msgpack::serialize_then_unwrap(&member_key); + assert_eq!(serialized.len(), member_key.0.len() + 2); + let deserialized: MemberKey = msgpack::deserialize_then_unwrap(&serialized); + assert_eq!(deserialized, member_key); + assert_eq!(deserialized.0, key.0); + } + } + + #[test] + fn channel_member_key_e2e() { + for _ in 0..100 { + let channel_id = ChannelId::from(thread_rng().next_u32()); + let user_id_bytes: [u8; 10] = thread_rng().gen(); + let user_id = UserId::from(Principal::from_slice(&user_id_bytes)); + let prefix = MemberKeyPrefix::new_from_channel(channel_id); + let key = Key::from(prefix.create_key(user_id)); + let member_key = MemberKey::try_from(key.clone()).unwrap(); + + assert_eq!(*member_key.0.first().unwrap(), KeyType::ChannelMember as u8); + assert_eq!(member_key.0.len(), 15); + assert!(member_key.matches_prefix(&prefix)); + assert_eq!(member_key.user_id(), user_id); + + let serialized = msgpack::serialize_then_unwrap(&member_key); + assert_eq!(serialized.len(), member_key.0.len() + 2); + let deserialized: MemberKey = msgpack::deserialize_then_unwrap(&serialized); + assert_eq!(deserialized, member_key); + assert_eq!(deserialized.0, key.0); + } + } + + #[test] + fn community_member_key_e2e() { + for _ in 0..100 { + let user_id_bytes: [u8; 10] = thread_rng().gen(); + let user_id = UserId::from(Principal::from_slice(&user_id_bytes)); + let prefix = MemberKeyPrefix::new_from_community(); + let key = Key::from(prefix.create_key(user_id)); + let member_key = MemberKey::try_from(key.clone()).unwrap(); + + assert_eq!(*member_key.0.first().unwrap(), KeyType::CommunityMember as u8); + assert_eq!(member_key.0.len(), 11); + assert!(member_key.matches_prefix(&prefix)); + assert_eq!(member_key.user_id(), user_id); + + let serialized = msgpack::serialize_then_unwrap(&member_key); + assert_eq!(serialized.len(), member_key.0.len() + 2); + let deserialized: MemberKey = msgpack::deserialize_then_unwrap(&serialized); + assert_eq!(deserialized, member_key); + assert_eq!(deserialized.0, key.0); + } + } +} diff --git a/backend/libraries/stable_memory_map/src/lib.rs b/backend/libraries/stable_memory_map/src/lib.rs index 7da7897442..904e50c60d 100644 --- a/backend/libraries/stable_memory_map/src/lib.rs +++ b/backend/libraries/stable_memory_map/src/lib.rs @@ -6,10 +6,14 @@ use ic_stable_structures::memory_manager::VirtualMemory; use ic_stable_structures::{DefaultMemoryImpl, StableBTreeMap}; use std::cell::RefCell; +mod key; + +pub use key::*; + pub type Memory = VirtualMemory; struct StableMemoryMap { - map: StableBTreeMap, Vec, Memory>, + map: StableBTreeMap, Memory>, } thread_local! { @@ -22,24 +26,24 @@ pub fn init(memory: Memory) { })); } -pub fn with_map, Vec, Memory>) -> R, R>(f: F) -> R { +pub fn with_map, Memory>) -> R, R>(f: F) -> R { MAP.with_borrow(|m| f(&m.as_ref().unwrap().map)) } -pub fn with_map_mut, Vec, Memory>) -> R, R>(f: F) -> R { +pub fn with_map_mut, Memory>) -> R, R>(f: F) -> R { MAP.with_borrow_mut(|m| f(&mut m.as_mut().unwrap().map)) } -pub fn garbage_collect(prefix: Vec) -> Result { - assert!(!prefix.is_empty()); +pub fn garbage_collect(prefix: KeyPrefix) -> Result { + // assert!(!prefix.is_empty()); let mut total_count = 0; with_map_mut(|m| { // If < 2B instructions have been used so far, delete another 100 keys, or exit if complete while ic_cdk::api::instruction_counter() < 2_000_000_000 { let keys: Vec<_> = m - .range(prefix.clone()..) - .take_while(|(k, _)| k.starts_with(&prefix)) + .range(Key::from(prefix.clone())..) + .take_while(|(k, _)| k.starts_with(prefix.as_slice())) .map(|(k, _)| k) .take(100) .collect(); @@ -57,32 +61,3 @@ pub fn garbage_collect(prefix: Vec) -> Result { Err(total_count) }) } - -#[derive(Copy, Clone)] -#[repr(u8)] -pub enum KeyType { - DirectChatEvent = 1, - GroupChatEvent = 2, - ChannelEvent = 3, - DirectChatThreadEvent = 4, - GroupChatThreadEvent = 5, - ChannelThreadEvent = 6, - ChatMember = 7, - CommunityMember = 8, -} - -impl From for KeyType { - fn from(value: u8) -> Self { - match value { - 1 => KeyType::DirectChatEvent, - 2 => KeyType::GroupChatEvent, - 3 => KeyType::ChannelEvent, - 4 => KeyType::DirectChatThreadEvent, - 5 => KeyType::GroupChatThreadEvent, - 6 => KeyType::ChannelThreadEvent, - 7 => KeyType::ChatMember, - 8 => KeyType::CommunityMember, - _ => unreachable!(), - } - } -} diff --git a/backend/libraries/types/src/chat_id.rs b/backend/libraries/types/src/chat_id.rs index 7285af9f04..03ef4d4c4c 100644 --- a/backend/libraries/types/src/chat_id.rs +++ b/backend/libraries/types/src/chat_id.rs @@ -2,6 +2,7 @@ use crate::{CanisterId, UserId}; use candid::{CandidType, Principal}; use serde::{Deserialize, Serialize}; use std::fmt::{Debug, Display, Formatter}; +use std::ops::Deref; use ts_export::ts_export; #[ts_export] @@ -38,8 +39,10 @@ impl Display for ChatId { } } -impl AsRef<[u8]> for ChatId { - fn as_ref(&self) -> &[u8] { - self.0.as_ref() +impl Deref for ChatId { + type Target = CanisterId; + + fn deref(&self) -> &Self::Target { + &self.0 } }