Skip to content

Commit

Permalink
Fix: add load functions for state
Browse files Browse the repository at this point in the history
  • Loading branch information
RainRaydium committed Sep 13, 2024
1 parent eebc18f commit a2163dc
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 15 deletions.
11 changes: 11 additions & 0 deletions dex/src/critbit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,17 @@ impl Slab {
slab
}

#[inline]
pub fn new_check(bytes: &[u8]) -> &Self {
let len_without_header = bytes.len().checked_sub(SLAB_HEADER_LEN).unwrap();
let slop = len_without_header % size_of::<AnyNode>();
let truncated_len = bytes.len() - slop;
let bytes = &bytes[..truncated_len];
let slab: &Self = unsafe { &*(bytes as *const [u8] as *const Slab) };
slab.check_size_align(); // check alignment
slab
}

#[inline]
pub fn assert_minimum_capacity(&self, capacity: u32) -> DexResult {
if self.nodes().len() <= (capacity as usize) * 2 {
Expand Down
189 changes: 174 additions & 15 deletions dex/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
#![cfg_attr(not(feature = "program"), allow(unused))]
use num_enum::TryFromPrimitive;
use std::{
cell::RefMut, convert::identity, convert::TryInto, mem::size_of, num::NonZeroU64, ops::Deref,
cell::{Ref, RefMut},
convert::identity,
convert::TryInto,
mem::size_of,
num::NonZeroU64,
ops::Deref,
ops::DerefMut,
};

use arrayref::{array_ref, array_refs, mut_array_refs};

use bytemuck::{
bytes_of, bytes_of_mut, cast, cast_slice, cast_slice_mut, from_bytes_mut, try_cast_mut,
try_cast_slice_mut, try_from_bytes_mut, Pod, Zeroable,
try_cast_slice, try_cast_slice_mut, try_from_bytes, try_from_bytes_mut, Pod, Zeroable,
};
use enumflags2::BitFlags;
use num_traits::FromPrimitive;
Expand Down Expand Up @@ -68,6 +73,8 @@ pub enum AccountFlag {
pub enum Market<'a> {
V1(RefMut<'a, MarketState>),
V2(RefMut<'a, MarketStateV2>),
V1Ref(Ref<'a, MarketState>),
V2Ref(Ref<'a, MarketStateV2>),
}

impl<'a> Deref for Market<'a> {
Expand All @@ -77,6 +84,8 @@ impl<'a> Deref for Market<'a> {
match self {
Market::V1(v1) => v1.deref(),
Market::V2(v2) => v2.deref(),
Market::V1Ref(v1_ref) => v1_ref.deref(),
Market::V2Ref(v2_ref) => v2_ref.deref(),
}
}
}
Expand All @@ -86,6 +95,7 @@ impl<'a> DerefMut for Market<'a> {
match self {
Market::V1(v1) => v1.deref_mut(),
Market::V2(v2) => v2.deref_mut(),
_ => unreachable!(),
}
}
}
Expand Down Expand Up @@ -114,6 +124,29 @@ impl<'a> Market<'a> {
}
}

#[inline]
pub fn load_checked(
market_account: &'a AccountInfo,
program_id: &Pubkey,
// Allow for the market flag to be set to AccountFlag::Disabled
allow_disabled: bool,
) -> DexResult<Self> {
let flags = Market::account_flags(&market_account.try_borrow_data()?)?;
if flags.intersects(AccountFlag::Permissioned) {
Ok(Market::V2Ref(MarketStateV2::load_checked(
market_account,
program_id,
allow_disabled,
)?))
} else {
Ok(Market::V1Ref(MarketState::load_checked(
market_account,
program_id,
allow_disabled,
)?))
}
}

pub fn account_flags(account_data: &[u8]) -> DexResult<BitFlags<AccountFlag>> {
let start = ACCOUNT_HEAD_PADDING.len();
let end = start + size_of::<AccountFlag>();
Expand All @@ -129,21 +162,23 @@ impl<'a> Market<'a> {

pub fn open_orders_authority(&self) -> Option<&Pubkey> {
match &self {
Market::V1(_) => None,
Market::V1(_) | Market::V1Ref(_) => None,
Market::V2(state) => Some(&state.open_orders_authority),
Market::V2Ref(state) => Some(&state.open_orders_authority),
}
}

pub fn prune_authority(&self) -> Option<&Pubkey> {
match &self {
Market::V1(_) => None,
Market::V1(_) | Market::V1Ref(_) => None,
Market::V2(state) => Some(&state.prune_authority),
Market::V2Ref(state) => Some(&state.prune_authority),
}
}

pub fn consume_events_authority(&self) -> Option<&Pubkey> {
match &self {
Market::V1(_) => None,
Market::V1(_) | Market::V1Ref(_) => None,
Market::V2(state) => {
let flags = BitFlags::from_bits(state.inner.account_flags).unwrap();
if flags.intersects(AccountFlag::CrankAuthorityRequired) {
Expand All @@ -152,6 +187,14 @@ impl<'a> Market<'a> {
None
}
}
Market::V2Ref(state) => {
let flags = BitFlags::from_bits(state.inner.account_flags).unwrap();
if flags.intersects(AccountFlag::CrankAuthorityRequired) {
Some(&state.consume_events_authority)
} else {
None
}
}
}
}

Expand Down Expand Up @@ -251,6 +294,29 @@ impl MarketStateV2 {
Ok(state)
}

#[inline]
pub fn load_checked<'a>(
market_account: &'a AccountInfo,
program_id: &Pubkey,
allow_disabled: bool,
) -> DexResult<Ref<'a, Self>> {
check_assert_eq!(market_account.owner, program_id)?;

let account_data = market_account.try_borrow_data()?;
check_assert!(account_data.len() >= 12)?;
let head = array_ref![account_data, 0, 5];
let tail = array_ref![account_data, account_data.len() - 6, 7];
check_assert_eq!(head, ACCOUNT_HEAD_PADDING)?;
check_assert_eq!(tail, ACCOUNT_TAIL_PADDING)?;

let state: Ref<'a, Self> = Ref::map(account_data, |account_data| {
bytemuck::from_bytes(&account_data[5..account_data.len() - 7])
});

state.check_flags(allow_disabled)?;
Ok(state)
}

#[inline]
pub fn check_flags(&self, allow_disabled: bool) -> DexResult {
let flags = BitFlags::from_bits(self.account_flags)
Expand Down Expand Up @@ -369,6 +435,16 @@ fn check_account_padding(data: &mut [u8]) -> DexResult<&mut [[u8; 8]]> {
Ok(try_cast_slice_mut(data).or(check_unreachable!())?)
}

fn check_account_padding_checked(account_data: &[u8]) -> DexResult<&[[u8; 8]]> {
check_assert!(account_data.len() >= 12)?;
let head = array_ref![account_data, 0, 5];
let tail = array_ref![account_data, account_data.len() - 6, 7];
check_assert_eq!(head, ACCOUNT_HEAD_PADDING)?;
check_assert_eq!(tail, ACCOUNT_TAIL_PADDING)?;
let data = &account_data[5..account_data.len() - 7];
Ok(try_cast_slice(&data).or(check_unreachable!())?)
}

fn strip_account_padding(padded_data: &mut [u8], init_allowed: bool) -> DexResult<&mut [[u8; 8]]> {
if init_allowed {
init_account_padding(padded_data)
Expand All @@ -377,6 +453,10 @@ fn strip_account_padding(padded_data: &mut [u8], init_allowed: bool) -> DexResul
}
}

fn strip_account_padding_checked(padded_data: &[u8]) -> DexResult<&[[u8; 8]]> {
check_account_padding_checked(padded_data)
}

pub fn strip_header<'a, H: Pod, D: Pod>(
account: &'a AccountInfo,
init_allowed: bool,
Expand Down Expand Up @@ -415,6 +495,43 @@ pub fn strip_header<'a, H: Pod, D: Pod>(
Ok((header, inner))
}

pub fn strip_header_checked<'a, H: Pod, D: Pod>(
account: &'a AccountInfo,
) -> DexResult<(Ref<'a, H>, Ref<'a, [D]>)> {
let mut result = Ok(());
let (header, inner): (Ref<'a, [H]>, Ref<'a, [D]>) =
Ref::map_split(account.try_borrow_data()?, |padded_data| {
let dummy_value: (&[H], &[D]) = (&[], &[]);
let padded_data: &[u8] = *padded_data;
let u64_data = match strip_account_padding_checked(padded_data) {
Ok(u64_data) => u64_data,
Err(e) => {
result = Err(e);
return dummy_value;
}
};

let data: &[u8] = cast_slice(u64_data);
let (header_bytes, inner_bytes) = data.split_at(size_of::<H>());
let header: &H;
let inner: &[D];

header = match try_from_bytes(header_bytes) {
Ok(h) => h,
Err(_e) => {
result = Err(assertion_error!().into());
return dummy_value;
}
};
inner = remove_slop(inner_bytes);

(std::slice::from_ref(header), inner)
});
result?;
let header = Ref::map(header, |s| s.first().unwrap_or_else(|| unreachable!()));
Ok((header, inner))
}

impl MarketState {
#[inline]
pub fn load<'a>(
Expand All @@ -438,6 +555,29 @@ impl MarketState {
Ok(state)
}

#[inline]
pub fn load_checked<'a>(
market_account: &'a AccountInfo,
program_id: &Pubkey,
allow_disabled: bool,
) -> DexResult<Ref<'a, Self>> {
check_assert_eq!(market_account.owner, program_id)?;

let account_data = market_account.try_borrow_data()?;
check_assert!(account_data.len() >= 12)?;
let head = array_ref![account_data, 0, 5];
let tail = array_ref![account_data, account_data.len() - 6, 7];
check_assert_eq!(head, ACCOUNT_HEAD_PADDING)?;
check_assert_eq!(tail, ACCOUNT_TAIL_PADDING)?;

let state: Ref<'a, Self> = Ref::map(account_data, |account_data| {
bytemuck::from_bytes(&account_data[5..account_data.len() - 7])
});

state.check_flags(allow_disabled)?;
Ok(state)
}

#[inline]
pub fn check_flags(&self, allow_disabled: bool) -> DexResult {
let flags = BitFlags::from_bits(self.account_flags)
Expand Down Expand Up @@ -465,6 +605,15 @@ impl MarketState {
Ok(RefMut::map(buf, Slab::new))
}

pub fn load_bids_checked<'a>(&self, bids: &'a AccountInfo) -> DexResult<Ref<'a, Slab>> {
check_assert_eq!(&bids.key.to_aligned_bytes(), &identity(self.bids))
.map_err(|_| DexErrorCode::WrongBidsAccount)?;
let (header, buf) = strip_header_checked::<OrderBookStateHeader, u8>(bids)?;
let flags = BitFlags::from_bits(header.account_flags).unwrap();
check_assert_eq!(&flags, &(AccountFlag::Initialized | AccountFlag::Bids))?;
Ok(Ref::map(buf, Slab::new_check))
}

pub fn load_asks_mut<'a>(&self, asks: &'a AccountInfo) -> DexResult<RefMut<'a, Slab>> {
check_assert_eq!(&asks.key.to_aligned_bytes(), &identity(self.asks))
.map_err(|_| DexErrorCode::WrongAsksAccount)?;
Expand All @@ -474,6 +623,15 @@ impl MarketState {
Ok(RefMut::map(buf, Slab::new))
}

pub fn load_asks_checked<'a>(&self, asks: &'a AccountInfo) -> DexResult<Ref<'a, Slab>> {
check_assert_eq!(&asks.key.to_aligned_bytes(), &identity(self.asks))
.map_err(|_| DexErrorCode::WrongAsksAccount)?;
let (header, buf) = strip_header_checked::<OrderBookStateHeader, u8>(asks)?;
let flags = BitFlags::from_bits(header.account_flags).unwrap();
check_assert_eq!(&flags, &(AccountFlag::Initialized | AccountFlag::Asks))?;
Ok(Ref::map(buf, Slab::new_check))
}

fn load_request_queue_mut<'a>(&self, queue: &'a AccountInfo) -> DexResult<RequestQueue<'a>> {
check_assert_eq!(&queue.key.to_aligned_bytes(), &identity(self.req_q))
.map_err(|_| DexErrorCode::WrongRequestQueueAccount)?;
Expand Down Expand Up @@ -632,22 +790,23 @@ impl OpenOrders {
}

#[inline]
pub fn load<'a>(
pub fn load_checked<'a>(
orders_account: &'a AccountInfo,
market_account: Option<&AccountInfo>,
owner_account: Option<&AccountInfo>,
program_id: &Pubkey,
) -> DexResult<RefMut<'a, Self>> {
) -> DexResult<Ref<'a, Self>> {
check_assert_eq!(orders_account.owner, program_id)?;
let mut account_data: RefMut<'a, [u8]>;
let state: RefMut<'a, Self>;

account_data = RefMut::map(orders_account.try_borrow_mut_data()?, |data| *data);
check_account_padding(&mut account_data)?;
state = RefMut::map(account_data, |data| {
from_bytes_mut(cast_slice_mut(
check_account_padding(data).unwrap_or_else(|_| unreachable!()),
))
let account_data = orders_account.try_borrow_data()?;
check_assert!(account_data.len() >= 12)?;
let head = array_ref![account_data, 0, 5];
let tail = array_ref![account_data, account_data.len() - 6, 7];
check_assert_eq!(head, ACCOUNT_HEAD_PADDING)?;
check_assert_eq!(tail, ACCOUNT_TAIL_PADDING)?;

let state: Ref<'a, Self> = Ref::map(account_data, |account_data| {
bytemuck::from_bytes(&account_data[5..account_data.len() - 7])
});

state.check_flags()?;
Expand Down

0 comments on commit a2163dc

Please sign in to comment.