diff --git a/dex/src/critbit.rs b/dex/src/critbit.rs index d6709eb..51e9904 100644 --- a/dex/src/critbit.rs +++ b/dex/src/critbit.rs @@ -302,6 +302,17 @@ impl Slab { slab } + #[inline] + pub fn new_check(bytes: &[u8]) -> &Self { + let len_without_header = bytes.len().checked_sub(SLAB_HEADER_LEN).unwrap(); + let slop = len_without_header % size_of::(); + let truncated_len = bytes.len() - slop; + let bytes = &bytes[..truncated_len]; + let slab: &Self = unsafe { &*(bytes as *const [u8] as *const Slab) }; + slab.check_size_align(); // check alignment + slab + } + #[inline] pub fn assert_minimum_capacity(&self, capacity: u32) -> DexResult { if self.nodes().len() <= (capacity as usize) * 2 { diff --git a/dex/src/state.rs b/dex/src/state.rs index ec10fcb..fff54cd 100644 --- a/dex/src/state.rs +++ b/dex/src/state.rs @@ -1,7 +1,12 @@ #![cfg_attr(not(feature = "program"), allow(unused))] use num_enum::TryFromPrimitive; use std::{ - cell::RefMut, convert::identity, convert::TryInto, mem::size_of, num::NonZeroU64, ops::Deref, + cell::{Ref, RefMut}, + convert::identity, + convert::TryInto, + mem::size_of, + num::NonZeroU64, + ops::Deref, ops::DerefMut, }; @@ -9,7 +14,7 @@ use arrayref::{array_ref, array_refs, mut_array_refs}; use bytemuck::{ bytes_of, bytes_of_mut, cast, cast_slice, cast_slice_mut, from_bytes_mut, try_cast_mut, - try_cast_slice_mut, try_from_bytes_mut, Pod, Zeroable, + try_cast_slice, try_cast_slice_mut, try_from_bytes, try_from_bytes_mut, Pod, Zeroable, }; use enumflags2::BitFlags; use num_traits::FromPrimitive; @@ -68,6 +73,8 @@ pub enum AccountFlag { pub enum Market<'a> { V1(RefMut<'a, MarketState>), V2(RefMut<'a, MarketStateV2>), + V1Ref(Ref<'a, MarketState>), + V2Ref(Ref<'a, MarketStateV2>), } impl<'a> Deref for Market<'a> { @@ -77,6 +84,8 @@ impl<'a> Deref for Market<'a> { match self { Market::V1(v1) => v1.deref(), Market::V2(v2) => v2.deref(), + Market::V1Ref(v1_ref) => v1_ref.deref(), + Market::V2Ref(v2_ref) => v2_ref.deref(), } } } @@ -86,6 +95,7 @@ impl<'a> DerefMut for Market<'a> { match self { Market::V1(v1) => v1.deref_mut(), Market::V2(v2) => v2.deref_mut(), + _ => unreachable!(), } } } @@ -114,6 +124,29 @@ impl<'a> Market<'a> { } } + #[inline] + pub fn load_checked( + market_account: &'a AccountInfo, + program_id: &Pubkey, + // Allow for the market flag to be set to AccountFlag::Disabled + allow_disabled: bool, + ) -> DexResult { + let flags = Market::account_flags(&market_account.try_borrow_data()?)?; + if flags.intersects(AccountFlag::Permissioned) { + Ok(Market::V2Ref(MarketStateV2::load_checked( + market_account, + program_id, + allow_disabled, + )?)) + } else { + Ok(Market::V1Ref(MarketState::load_checked( + market_account, + program_id, + allow_disabled, + )?)) + } + } + pub fn account_flags(account_data: &[u8]) -> DexResult> { let start = ACCOUNT_HEAD_PADDING.len(); let end = start + size_of::(); @@ -129,21 +162,23 @@ impl<'a> Market<'a> { pub fn open_orders_authority(&self) -> Option<&Pubkey> { match &self { - Market::V1(_) => None, + Market::V1(_) | Market::V1Ref(_) => None, Market::V2(state) => Some(&state.open_orders_authority), + Market::V2Ref(state) => Some(&state.open_orders_authority), } } pub fn prune_authority(&self) -> Option<&Pubkey> { match &self { - Market::V1(_) => None, + Market::V1(_) | Market::V1Ref(_) => None, Market::V2(state) => Some(&state.prune_authority), + Market::V2Ref(state) => Some(&state.prune_authority), } } pub fn consume_events_authority(&self) -> Option<&Pubkey> { match &self { - Market::V1(_) => None, + Market::V1(_) | Market::V1Ref(_) => None, Market::V2(state) => { let flags = BitFlags::from_bits(state.inner.account_flags).unwrap(); if flags.intersects(AccountFlag::CrankAuthorityRequired) { @@ -152,6 +187,14 @@ impl<'a> Market<'a> { None } } + Market::V2Ref(state) => { + let flags = BitFlags::from_bits(state.inner.account_flags).unwrap(); + if flags.intersects(AccountFlag::CrankAuthorityRequired) { + Some(&state.consume_events_authority) + } else { + None + } + } } } @@ -251,6 +294,29 @@ impl MarketStateV2 { Ok(state) } + #[inline] + pub fn load_checked<'a>( + market_account: &'a AccountInfo, + program_id: &Pubkey, + allow_disabled: bool, + ) -> DexResult> { + check_assert_eq!(market_account.owner, program_id)?; + + let account_data = market_account.try_borrow_data()?; + check_assert!(account_data.len() >= 12)?; + let head = array_ref![account_data, 0, 5]; + let tail = array_ref![account_data, account_data.len() - 6, 7]; + check_assert_eq!(head, ACCOUNT_HEAD_PADDING)?; + check_assert_eq!(tail, ACCOUNT_TAIL_PADDING)?; + + let state: Ref<'a, Self> = Ref::map(account_data, |account_data| { + bytemuck::from_bytes(&account_data[5..account_data.len() - 7]) + }); + + state.check_flags(allow_disabled)?; + Ok(state) + } + #[inline] pub fn check_flags(&self, allow_disabled: bool) -> DexResult { let flags = BitFlags::from_bits(self.account_flags) @@ -369,6 +435,16 @@ fn check_account_padding(data: &mut [u8]) -> DexResult<&mut [[u8; 8]]> { Ok(try_cast_slice_mut(data).or(check_unreachable!())?) } +fn check_account_padding_checked(account_data: &[u8]) -> DexResult<&[[u8; 8]]> { + check_assert!(account_data.len() >= 12)?; + let head = array_ref![account_data, 0, 5]; + let tail = array_ref![account_data, account_data.len() - 6, 7]; + check_assert_eq!(head, ACCOUNT_HEAD_PADDING)?; + check_assert_eq!(tail, ACCOUNT_TAIL_PADDING)?; + let data = &account_data[5..account_data.len() - 7]; + Ok(try_cast_slice(&data).or(check_unreachable!())?) +} + fn strip_account_padding(padded_data: &mut [u8], init_allowed: bool) -> DexResult<&mut [[u8; 8]]> { if init_allowed { init_account_padding(padded_data) @@ -377,6 +453,10 @@ fn strip_account_padding(padded_data: &mut [u8], init_allowed: bool) -> DexResul } } +fn strip_account_padding_checked(padded_data: &[u8]) -> DexResult<&[[u8; 8]]> { + check_account_padding_checked(padded_data) +} + pub fn strip_header<'a, H: Pod, D: Pod>( account: &'a AccountInfo, init_allowed: bool, @@ -415,6 +495,43 @@ pub fn strip_header<'a, H: Pod, D: Pod>( Ok((header, inner)) } +pub fn strip_header_checked<'a, H: Pod, D: Pod>( + account: &'a AccountInfo, +) -> DexResult<(Ref<'a, H>, Ref<'a, [D]>)> { + let mut result = Ok(()); + let (header, inner): (Ref<'a, [H]>, Ref<'a, [D]>) = + Ref::map_split(account.try_borrow_data()?, |padded_data| { + let dummy_value: (&[H], &[D]) = (&[], &[]); + let padded_data: &[u8] = *padded_data; + let u64_data = match strip_account_padding_checked(padded_data) { + Ok(u64_data) => u64_data, + Err(e) => { + result = Err(e); + return dummy_value; + } + }; + + let data: &[u8] = cast_slice(u64_data); + let (header_bytes, inner_bytes) = data.split_at(size_of::()); + let header: &H; + let inner: &[D]; + + header = match try_from_bytes(header_bytes) { + Ok(h) => h, + Err(_e) => { + result = Err(assertion_error!().into()); + return dummy_value; + } + }; + inner = remove_slop(inner_bytes); + + (std::slice::from_ref(header), inner) + }); + result?; + let header = Ref::map(header, |s| s.first().unwrap_or_else(|| unreachable!())); + Ok((header, inner)) +} + impl MarketState { #[inline] pub fn load<'a>( @@ -438,6 +555,29 @@ impl MarketState { Ok(state) } + #[inline] + pub fn load_checked<'a>( + market_account: &'a AccountInfo, + program_id: &Pubkey, + allow_disabled: bool, + ) -> DexResult> { + check_assert_eq!(market_account.owner, program_id)?; + + let account_data = market_account.try_borrow_data()?; + check_assert!(account_data.len() >= 12)?; + let head = array_ref![account_data, 0, 5]; + let tail = array_ref![account_data, account_data.len() - 6, 7]; + check_assert_eq!(head, ACCOUNT_HEAD_PADDING)?; + check_assert_eq!(tail, ACCOUNT_TAIL_PADDING)?; + + let state: Ref<'a, Self> = Ref::map(account_data, |account_data| { + bytemuck::from_bytes(&account_data[5..account_data.len() - 7]) + }); + + state.check_flags(allow_disabled)?; + Ok(state) + } + #[inline] pub fn check_flags(&self, allow_disabled: bool) -> DexResult { let flags = BitFlags::from_bits(self.account_flags) @@ -465,6 +605,15 @@ impl MarketState { Ok(RefMut::map(buf, Slab::new)) } + pub fn load_bids_checked<'a>(&self, bids: &'a AccountInfo) -> DexResult> { + check_assert_eq!(&bids.key.to_aligned_bytes(), &identity(self.bids)) + .map_err(|_| DexErrorCode::WrongBidsAccount)?; + let (header, buf) = strip_header_checked::(bids)?; + let flags = BitFlags::from_bits(header.account_flags).unwrap(); + check_assert_eq!(&flags, &(AccountFlag::Initialized | AccountFlag::Bids))?; + Ok(Ref::map(buf, Slab::new_check)) + } + pub fn load_asks_mut<'a>(&self, asks: &'a AccountInfo) -> DexResult> { check_assert_eq!(&asks.key.to_aligned_bytes(), &identity(self.asks)) .map_err(|_| DexErrorCode::WrongAsksAccount)?; @@ -474,6 +623,15 @@ impl MarketState { Ok(RefMut::map(buf, Slab::new)) } + pub fn load_asks_checked<'a>(&self, asks: &'a AccountInfo) -> DexResult> { + check_assert_eq!(&asks.key.to_aligned_bytes(), &identity(self.asks)) + .map_err(|_| DexErrorCode::WrongAsksAccount)?; + let (header, buf) = strip_header_checked::(asks)?; + let flags = BitFlags::from_bits(header.account_flags).unwrap(); + check_assert_eq!(&flags, &(AccountFlag::Initialized | AccountFlag::Asks))?; + Ok(Ref::map(buf, Slab::new_check)) + } + fn load_request_queue_mut<'a>(&self, queue: &'a AccountInfo) -> DexResult> { check_assert_eq!(&queue.key.to_aligned_bytes(), &identity(self.req_q)) .map_err(|_| DexErrorCode::WrongRequestQueueAccount)?; @@ -632,22 +790,23 @@ impl OpenOrders { } #[inline] - pub fn load<'a>( + pub fn load_checked<'a>( orders_account: &'a AccountInfo, market_account: Option<&AccountInfo>, owner_account: Option<&AccountInfo>, program_id: &Pubkey, - ) -> DexResult> { + ) -> DexResult> { check_assert_eq!(orders_account.owner, program_id)?; - let mut account_data: RefMut<'a, [u8]>; - let state: RefMut<'a, Self>; - account_data = RefMut::map(orders_account.try_borrow_mut_data()?, |data| *data); - check_account_padding(&mut account_data)?; - state = RefMut::map(account_data, |data| { - from_bytes_mut(cast_slice_mut( - check_account_padding(data).unwrap_or_else(|_| unreachable!()), - )) + let account_data = orders_account.try_borrow_data()?; + check_assert!(account_data.len() >= 12)?; + let head = array_ref![account_data, 0, 5]; + let tail = array_ref![account_data, account_data.len() - 6, 7]; + check_assert_eq!(head, ACCOUNT_HEAD_PADDING)?; + check_assert_eq!(tail, ACCOUNT_TAIL_PADDING)?; + + let state: Ref<'a, Self> = Ref::map(account_data, |account_data| { + bytemuck::from_bytes(&account_data[5..account_data.len() - 7]) }); state.check_flags()?;