From bb2fe64d81f0007028fa332d0a35f8c2f944f882 Mon Sep 17 00:00:00 2001 From: linuskendall Date: Mon, 15 Jan 2024 08:33:41 +0000 Subject: [PATCH] rpc: add solend patch (gPA filter) Adds a patch that allows more complex queries for gPA. --- rpc-client-api/src/filter.rs | 231 +++++++++++++++++++++++++++++++++++ 1 file changed, 231 insertions(+) diff --git a/rpc-client-api/src/filter.rs b/rpc-client-api/src/filter.rs index 398bd9807d..aedac65ca5 100644 --- a/rpc-client-api/src/filter.rs +++ b/rpc-client-api/src/filter.rs @@ -17,6 +17,7 @@ pub enum RpcFilterType { DataSize(u64), Memcmp(Memcmp), TokenAccountState, + ValueCmp(ValueCmp), } impl RpcFilterType { @@ -75,6 +76,7 @@ impl RpcFilterType { } } } + RpcFilterType::ValueCmp(_) => Ok(()), RpcFilterType::TokenAccountState => Ok(()), } } @@ -83,6 +85,9 @@ impl RpcFilterType { match self { RpcFilterType::DataSize(size) => account.data().len() as u64 == *size, RpcFilterType::Memcmp(compare) => compare.bytes_match(account.data()), + RpcFilterType::ValueCmp(compare) => { + compare.values_match(account.data()).unwrap_or(false) + } RpcFilterType::TokenAccountState => Account::valid_account_data(account.data()), } } @@ -108,6 +113,8 @@ pub enum RpcFilterError { Base58DecodeError(#[from] bs58::decode::Error), #[error("base64 decode error")] Base64DecodeError(#[from] base64::DecodeError), + #[error("invalid filter")] + InvalidFilter, } #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] @@ -220,6 +227,178 @@ impl Memcmp { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct ValueCmp { + pub left: Operand, + comparator: Comparator, + pub right: Operand, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum Operand { + Mem { + offset: usize, + value_type: ValueType, + }, + Constant(String), +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum ValueType { + U8, + U16, + U32, + U64, + U128, +} + +enum WrappedValueType { + U8(u8), + U16(u16), + U32(u32), + U64(u64), + U128(u128), +} + +impl ValueCmp { + fn parse_mem_into_value_type( + o: &Operand, + data: &[u8], + ) -> Result { + match o { + Operand::Mem { offset, value_type } => match value_type { + ValueType::U8 => { + if *offset >= data.len() { + return Err(RpcFilterError::InvalidFilter); + } + + Ok(WrappedValueType::U8(data[*offset])) + } + ValueType::U16 => { + if *offset + 1 >= data.len() { + return Err(RpcFilterError::InvalidFilter); + } + Ok(WrappedValueType::U16(u16::from_le_bytes( + data[*offset..*offset + 2].try_into().unwrap(), + ))) + } + ValueType::U32 => { + if *offset + 3 >= data.len() { + return Err(RpcFilterError::InvalidFilter); + } + Ok(WrappedValueType::U32(u32::from_le_bytes( + data[*offset..*offset + 4].try_into().unwrap(), + ))) + } + ValueType::U64 => { + if *offset + 7 >= data.len() { + return Err(RpcFilterError::InvalidFilter); + } + Ok(WrappedValueType::U64(u64::from_le_bytes( + data[*offset..*offset + 8].try_into().unwrap(), + ))) + } + ValueType::U128 => { + if *offset + 15 >= data.len() { + return Err(RpcFilterError::InvalidFilter); + } + Ok(WrappedValueType::U128(u128::from_le_bytes( + data[*offset..*offset + 16].try_into().unwrap(), + ))) + } + }, + _ => Err(RpcFilterError::InvalidFilter), + } + } + + pub fn values_match(&self, data: &[u8]) -> Result { + match (&self.left, &self.right) { + (left @ Operand::Mem { .. }, right @ Operand::Mem { .. }) => { + let left = Self::parse_mem_into_value_type(left, data)?; + let right = Self::parse_mem_into_value_type(right, data)?; + + match (left, right) { + (WrappedValueType::U8(left), WrappedValueType::U8(right)) => { + Ok(self.comparator.compare(left, right)) + } + (WrappedValueType::U16(left), WrappedValueType::U16(right)) => { + Ok(self.comparator.compare(left, right)) + } + (WrappedValueType::U32(left), WrappedValueType::U32(right)) => { + Ok(self.comparator.compare(left, right)) + } + (WrappedValueType::U64(left), WrappedValueType::U64(right)) => { + Ok(self.comparator.compare(left, right)) + } + (WrappedValueType::U128(left), WrappedValueType::U128(right)) => { + Ok(self.comparator.compare(left, right)) + } + _ => Err(RpcFilterError::InvalidFilter), + } + } + (left @ Operand::Mem { .. }, Operand::Constant(constant)) => { + match Self::parse_mem_into_value_type(left, data)? { + WrappedValueType::U8(left) => { + let right = constant + .parse::() + .map_err(|_| RpcFilterError::InvalidFilter)?; + Ok(self.comparator.compare(left, right)) + } + WrappedValueType::U16(left) => { + let right = constant + .parse::() + .map_err(|_| RpcFilterError::InvalidFilter)?; + Ok(self.comparator.compare(left, right)) + } + WrappedValueType::U32(left) => { + let right = constant + .parse::() + .map_err(|_| RpcFilterError::InvalidFilter)?; + Ok(self.comparator.compare(left, right)) + } + WrappedValueType::U64(left) => { + let right = constant + .parse::() + .map_err(|_| RpcFilterError::InvalidFilter)?; + Ok(self.comparator.compare(left, right)) + } + WrappedValueType::U128(left) => { + let right = constant + .parse::() + .map_err(|_| RpcFilterError::InvalidFilter)?; + Ok(self.comparator.compare(left, right)) + } + } + } + _ => Err(RpcFilterError::InvalidFilter), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum Comparator { + Eq = 0, + Ne, + Gt, + Ge, + Lt, + Le, +} + +impl Comparator { + // write a generic function to compare two values + pub fn compare(&self, left: T, right: T) -> bool { + match self { + Comparator::Eq => left == right, + Comparator::Ne => left != right, + Comparator::Gt => left > right, + Comparator::Ge => left >= right, + Comparator::Lt => left < right, + Comparator::Le => left <= right, + } + } +} + // Internal struct to hold Memcmp filter data as either encoded String or raw Bytes #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(untagged)] @@ -395,6 +574,58 @@ mod tests { .bytes_match(&data)); } + #[test] + fn test_values_match() { + // test all the ValueCmp cases + let data = vec![1, 2, 3, 4, 5]; + + let filter = ValueCmp { + left: Operand::Mem { + offset: 1, + value_type: ValueType::U8, + }, + comparator: Comparator::Eq, + right: Operand::Constant("2".to_string()), + }; + + assert!(ValueCmp { + left: Operand::Mem { + offset: 1, + value_type: ValueType::U8 + }, + comparator: Comparator::Eq, + right: Operand::Constant("2".to_string()) + } + .values_match(&data) + .unwrap()); + + assert!(ValueCmp { + left: Operand::Mem { + offset: 1, + value_type: ValueType::U8 + }, + comparator: Comparator::Lt, + right: Operand::Constant("3".to_string()) + } + .values_match(&data) + .unwrap()); + + assert!(ValueCmp { + left: Operand::Mem { + offset: 0, + value_type: ValueType::U32 + }, + comparator: Comparator::Eq, + right: Operand::Constant("67305985".to_string()) + } + .values_match(&data) + .unwrap()); + + // serialize + let s = serde_json::to_string(&filter).unwrap(); + println!("{}", s); + } + #[test] fn test_verify_memcmp() { let base58_bytes = "\