Skip to content

Commit

Permalink
feat: establish proof system version handling for onchain prover
Browse files Browse the repository at this point in the history
  • Loading branch information
jac18281828 committed Nov 6, 2024
1 parent 411b7df commit 623765b
Show file tree
Hide file tree
Showing 7 changed files with 278 additions and 57 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 37 additions & 11 deletions onchain/bonsol/src/actions/status.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
use crate::{
assertions::*,
error::ChannelError,
proof_handling::{output_digest, prepare_inputs, verify_risc0},
proof_handling::{output_digest_v1_0_1, prepare_inputs_v1_0_1, verify_risc0_v1_0_1},
utilities::*,
};

use bonsol_interface::{
bonsol_schema::{root_as_execution_request_v1, ChannelInstruction, ExitCode, StatusV1},
bonsol_schema::{
root_as_execution_request_v1, ChannelInstruction, ExecutionRequestV1, ExitCode, StatusV1,
},
prover_version::{ProverVersion, VERSION_V1_0_1},
util::execution_address_seeds,
};

use solana_program::{
account_info::AccountInfo,
clock::Clock,
Expand Down Expand Up @@ -94,15 +99,7 @@ pub fn process_status_v1<'a>(
er.input_digest()
.map(|x| check_bytes_match(x.bytes(), input_digest, ChannelError::InputsDontMatch));
}
let output_digest = output_digest(input_digest, co, asud);
let proof_inputs = prepare_inputs(
er.image_id().unwrap(),
exed,
output_digest.as_ref(),
st.exit_code_system(),
st.exit_code_user(),
)?;
let verified = verify_risc0(proof, &proof_inputs)?;
let verified = verify_with_prover(input_digest, co, asud, er, exed, st, proof)?;
let tip = er.tip();
if verified {
let callback_program_set =
Expand Down Expand Up @@ -185,3 +182,32 @@ pub fn process_status_v1<'a>(
}
Ok(())
}

fn verify_with_prover(
input_digest: &[u8],
co: &[u8],
asud: &[u8],
er: ExecutionRequestV1,
exed: &[u8],
st: StatusV1,
proof: &[u8; 256],
) -> Result<bool, ProgramError> {
let prover_version =
ProverVersion::try_from(er.prover_version()).unwrap_or(ProverVersion::default());

let verified = match prover_version {
VERSION_V1_0_1 => {
let output_digest = output_digest_v1_0_1(input_digest, co, asud);
let proof_inputs = prepare_inputs_v1_0_1(
er.image_id().unwrap(),
exed,
output_digest.as_ref(),
st.exit_code_system(),
st.exit_code_user(),
)?;
verify_risc0_v1_0_1(proof, &proof_inputs)?
}
_ => false,
};
Ok(verified)
}
2 changes: 2 additions & 0 deletions onchain/bonsol/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ pub enum ChannelError {
InvalidExecutionId,
#[error("Invalid Execution Account Owner")]
InvalidExecutionAccountOwner,
#[error("Unexpected Proof System")]
UnexpectedProofSystem,
}

impl From<ChannelError> for ProgramError {
Expand Down
8 changes: 6 additions & 2 deletions onchain/bonsol/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
#![allow(clippy::arithmetic_side_effects)]
#![cfg_attr(not(test), forbid(unsafe_code))]
use solana_program::declare_id;

pub mod actions;
mod assertions;
pub mod error;
pub mod program;
pub mod proof_handling;
pub mod prover;
pub mod utilities;

mod assertions;
mod verifying_key;

use solana_program::declare_id;

declare_id!("BoNsHRcyLLNdtnoDf8hiCNZpyehMC4FDMxs6NTxFi3ew");

#[cfg(not(feature = "no-entrypoint"))]
Expand Down
200 changes: 158 additions & 42 deletions onchain/bonsol/src/proof_handling.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,29 @@
use crate::{error::ChannelError, verifying_key::VERIFYINGKEY};
use std::ops::Neg;

use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
use groth16_solana::groth16::Groth16Verifier;
use hex_literal::hex;
use solana_program::hash::hashv;
use std::ops::Neg;
type G1 = ark_bn254::g1::G1Affine;

fn sized_range<const N: usize>(slice: &[u8]) -> Result<[u8; N], ChannelError> {
slice
.try_into()
.map_err(|_| ChannelError::InvalidInstruction)
}
use crate::{
error::ChannelError,
prover::{Groth16Prover, PROVER_CONSTANTS_V1_0_1},
verifying_key::VERIFYINGKEY,
};

fn change_endianness(bytes: &[u8]) -> Vec<u8> {
let mut vec = Vec::new();
for b in bytes.chunks(32) {
for byte in b.iter().rev() {
vec.push(*byte);
}
type G1 = ark_bn254::g1::G1Affine;

pub fn verify_risc0(
proof: &[u8],
inputs: &[u8],
groth16_prover: Groth16Prover,
) -> Result<bool, ChannelError> {
match groth16_prover {
Groth16Prover::V1_0_1 => verify_risc0_v1_0_1(proof, inputs),
_ => Err(ChannelError::UnexpectedProofSystem),
}
vec
}

pub fn verify_risc0(proof: &[u8], inputs: &[u8]) -> Result<bool, ChannelError> {
pub fn verify_risc0_v1_0_1(proof: &[u8], inputs: &[u8]) -> Result<bool, ChannelError> {
let ace: Vec<u8> = change_endianness(&*[&proof[0..64], &[0u8][..]].concat());
let proof_a: G1 = G1::deserialize_with_mode(&*ace, Compress::No, Validate::No).unwrap();

Expand Down Expand Up @@ -58,32 +59,23 @@ pub fn verify_risc0(proof: &[u8], inputs: &[u8]) -> Result<bool, ChannelError> {
.map_err(|_| ChannelError::ProofVerificationFailed)
}

const CONTROL_ROOT: [u8; 32] =
hex!("a516a057c9fbf5629106300934d48e0e775d4230e41e503347cad96fcbde7e2e");
const BN254_CONTROL_ID_BYTES: [u8; 32] =
hex!("0eb6febcf06c5df079111be116f79bd8c7e85dc9448776ef9a59aaf2624ab551");
const OUTPUT_HASH: [u8; 32] =
hex!("77eafeb366a78b47747de0d7bb176284085ff5564887009a5be63da32d3559d4");
const RECIEPT_CLAIM_HASH: [u8; 32] =
hex!("cb1fefcd1f2d9a64975cbbbf6e161e2914434b0cbb9960b84df5d717e86b48af");

pub fn output_digest(
pub fn output_digest_v1_0_1(
input_digest: &[u8],
committed_outputs: &[u8],
assumption_digest: &[u8],
) -> [u8; 32] {
let jbytes = [input_digest, committed_outputs].concat(); // bad copy here
let journal = hashv(&[jbytes.as_slice()]);
hashv(&[
OUTPUT_HASH.as_ref(),
PROVER_CONSTANTS_V1_0_1.output_hash.as_ref(),
journal.as_ref(),
assumption_digest,
&2u16.to_le_bytes(),
])
.to_bytes()
}

pub fn prepare_inputs(
pub fn prepare_inputs_v1_0_1(
image_id: &str,
execution_digest: &[u8],
output_digest: &[u8],
Expand All @@ -92,7 +84,7 @@ pub fn prepare_inputs(
) -> Result<Vec<u8>, ChannelError> {
let imgbytes = hex::decode(image_id).map_err(|_| ChannelError::InvalidFieldElement)?;
let mut digest = hashv(&[
RECIEPT_CLAIM_HASH.as_ref(),
PROVER_CONSTANTS_V1_0_1.receipt_claim_hash.as_ref(),
&[0u8; 32],
&imgbytes,
execution_digest,
Expand All @@ -102,32 +94,156 @@ pub fn prepare_inputs(
&4u16.to_le_bytes(),
])
.to_bytes();
let (c0, c1) =
split_digest(&mut CONTROL_ROOT.clone()).map_err(|_| ChannelError::InvalidFieldElement)?;
let (c0, c1) = split_digest_reversed(&mut PROVER_CONSTANTS_V1_0_1.control_root.clone())
.map_err(|_| ChannelError::InvalidFieldElement)?;
let (half1_bytes, half2_bytes) =
split_digest(&mut digest).map_err(|_| ChannelError::InvalidFieldElement)?;
split_digest_reversed(&mut digest).map_err(|_| ChannelError::InvalidFieldElement)?;
let inputs = [
c0,
c1,
half1_bytes.try_into().unwrap(),
half2_bytes.try_into().unwrap(),
BN254_CONTROL_ID_BYTES,
PROVER_CONSTANTS_V1_0_1.bn254_control_id_bytes,
]
.concat();
Ok(inputs)
}

pub fn split_digest(d: &mut [u8]) -> Result<([u8; 32], [u8; 32]), ChannelError> {
/**
* Reverse and split a digest into two halves
* The first half is the left half of the digest
* The second half is the right half of the digest
*
* @param d: The digest to split
* @return A tuple containing the left and right halves of the digest
*/
pub fn split_digest_reversed_256(d: &mut [u8]) -> Result<([u8; 32], [u8; 32]), ChannelError> {
split_digest_reversed::<32>(d)
}

fn split_digest_reversed<const N: usize>(d: &mut [u8]) -> Result<([u8; N], [u8; N]), ChannelError> {
if d.len() != N {
return Err(ChannelError::UnexpectedProofSystem);
}
d.reverse();
let (a, b) = d.split_at(16);
let af = to_fixed_array(a.to_vec());
let bf = to_fixed_array(b.to_vec());
let split_index = (N + 1) / 2;
let (a, b) = d.split_at(split_index);
let af = to_fixed_array(a);
let bf = to_fixed_array(b);
Ok((bf, af))
}

fn to_fixed_array(input: Vec<u8>) -> [u8; 32] {
let mut fixed_array = [0u8; 32];
let start = core::cmp::max(32, input.len()) - core::cmp::min(32, input.len());
fixed_array[start..].copy_from_slice(&input[input.len().saturating_sub(32)..]);
fn to_fixed_array<const N: usize>(input: &[u8]) -> [u8; N] {
let mut fixed_array = [0u8; N];
if input.len() >= N {
// Copy the last N bytes of input into fixed_array
fixed_array.copy_from_slice(&input[input.len() - N..]);
} else {
// Copy input into the end of fixed_array
let start = N - input.len();
fixed_array[start..].copy_from_slice(input);
}
fixed_array
}

fn sized_range<const N: usize>(slice: &[u8]) -> Result<[u8; N], ChannelError> {
slice
.try_into()
.map_err(|_| ChannelError::InvalidInstruction)
}

fn change_endianness(bytes: &[u8]) -> Vec<u8> {
let mut vec = Vec::with_capacity(bytes.len());
let chunk_size = 32;

for chunk in bytes.chunks(chunk_size) {
// Reverse the chunk and extend the vector
vec.extend(chunk.iter().rev());
}

vec
}
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_change_endianness() {
let bytes = [1u8, 2, 3, 4, 5, 6, 7, 8];
let expected = [8u8, 7, 6, 5, 4, 3, 2, 1];
assert_eq!(change_endianness(&bytes), expected);
}

#[test]
fn test_change_endianness_odd() {
let bytes = [1u8, 2, 3, 4, 5, 6, 7];
let expected = [7u8, 6, 5, 4, 3, 2, 1];
assert_eq!(change_endianness(&bytes), expected);
}

#[test]
fn test_change_endianness_double_word() {
let bytes = [1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let expected = [16u8, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1];
assert_eq!(change_endianness(&bytes), expected);
}

#[test]
fn test_split_digest() {
let mut digest = [1u8; 32];
digest[0] = 103;
let (a, b) = split_digest_reversed(&mut digest).unwrap();
let expect_digest_right = to_fixed_array::<32>(&[1u8; 16]);
let mut expect_digest_left = expect_digest_right.clone();
expect_digest_left[31] = 103;
assert_eq!(a, expect_digest_left);
assert_eq!(b, expect_digest_right);
}

#[test]
fn test_split_digest_odd() {
let mut digest = [1u8; 31];
digest[0] = 103;
let (a, b) = split_digest_reversed(&mut digest).unwrap();
let expect_digest_right = to_fixed_array::<31>(&[1u8; 16]);
let mut expect_digest_left = to_fixed_array::<31>(&[1u8; 15]);
expect_digest_left[30] = 103;
assert_eq!(a, expect_digest_left);
assert_eq!(b, expect_digest_right);
}

#[test]
fn test_split_digest_16() {
let digest = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
let (a, b) = split_digest_reversed::<16>(&mut digest.to_vec()).unwrap();
let expect_digest_left = to_fixed_array::<16>(&[7, 6, 5, 4, 3, 2, 1, 0]);
let expect_digest_right = to_fixed_array::<16>(&[15, 14, 13, 12, 11, 10, 9, 8]);
assert_eq!(a, expect_digest_left);
assert_eq!(b, expect_digest_right);
}

#[test]
fn test_split_digest_8() {
let digest = [0, 1, 2, 3, 4, 5, 6, 7];
let (a, b) = split_digest_reversed::<8>(&mut digest.to_vec()).unwrap();
let expect_digest_left = to_fixed_array::<8>(&[3, 2, 1, 0]);
let expect_digest_right = to_fixed_array::<8>(&[7, 6, 5, 4]);
assert_eq!(a, expect_digest_left);
assert_eq!(b, expect_digest_right);
}

#[test]
fn test_invalid_digest_wrong_size() {
let mut d1 = [1u8; 31];
assert!(split_digest_reversed_256(&mut d1).is_err());
let mut d2 = [1u8; 33];
assert!(split_digest_reversed_256(&mut d2).is_err());
}

#[test]
fn test_sized_range() {
let slice = [1u8; 32];
let expected = [1u8; 32];
assert_eq!(sized_range::<32>(&slice).unwrap(), expected);
}
}
Loading

0 comments on commit 623765b

Please sign in to comment.