diff --git a/Cargo.lock b/Cargo.lock index e8992204cf..e84c8ea422 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5326,6 +5326,7 @@ name = "sp1-recursion-compiler" version = "1.1.1" dependencies = [ "backtrace", + "criterion", "itertools 0.13.0", "p3-air", "p3-baby-bear", diff --git a/crates/recursion/compiler/Cargo.toml b/crates/recursion/compiler/Cargo.toml index fd80194fd4..88fa4ed3a0 100644 --- a/crates/recursion/compiler/Cargo.toml +++ b/crates/recursion/compiler/Cargo.toml @@ -40,3 +40,8 @@ p3-challenger = { workspace = true } p3-dft = { workspace = true } p3-merkle-tree = { workspace = true } rand = "0.8.5" +criterion = { version = "0.5.1", features = ["html_reports"] } + +[[bench]] +name = "circuit" +harness = false diff --git a/crates/recursion/compiler/benches/circuit.rs b/crates/recursion/compiler/benches/circuit.rs new file mode 100644 index 0000000000..42d2e94ec5 --- /dev/null +++ b/crates/recursion/compiler/benches/circuit.rs @@ -0,0 +1,85 @@ +use std::time::Duration; + +use criterion::*; +use p3_symmetric::Permutation; +use rand::{rngs::StdRng, Rng, SeedableRng}; + +use sp1_recursion_compiler::{ + asm::{AsmBuilder, AsmConfig}, + circuit::*, + ir::{DslIr, TracedVec}, + prelude::Felt, +}; +use sp1_recursion_core_v2::chips::poseidon2_wide::WIDTH; +use sp1_stark::{baby_bear_poseidon2::BabyBearPoseidon2, inner_perm, StarkGenericConfig}; + +type SC = BabyBearPoseidon2; +type F = ::Val; +type EF = ::Challenge; +type C = AsmConfig; + +fn poseidon_program() -> TracedVec> { + let mut builder = AsmBuilder::::default(); + let mut rng = StdRng::seed_from_u64(0xCAFEDA7E) + .sample_iter::<[F; WIDTH], _>(rand::distributions::Standard); + for _ in 0..100 { + let input_1: [F; WIDTH] = rng.next().unwrap(); + let output_1 = inner_perm().permute(input_1); + + let input_1_felts = input_1.map(|x| builder.eval(x)); + let output_1_felts = builder.poseidon2_permute_v2(input_1_felts); + let expected: [Felt<_>; WIDTH] = output_1.map(|x| builder.eval(x)); + for (lhs, rhs) in output_1_felts.into_iter().zip(expected) { + builder.assert_felt_eq(lhs, rhs); + } + } + builder.operations +} + +#[allow(dead_code)] +fn compile_one(c: &mut Criterion) { + let input = { + let mut ops = poseidon_program().vec; + ops.truncate(100); + ops + }; + + c.bench_with_input( + BenchmarkId::new("compile_one", format!("{} instructions", input.len())), + &input, + |b, operations| { + let mut compiler = AsmCompiler::>::default(); + b.iter(|| { + for instr in operations.iter().cloned() { + compiler.compile_one(std::hint::black_box(instr), drop); + } + compiler.next_addr = Default::default(); + compiler.virtual_to_physical.clear(); + compiler.consts.clear(); + compiler.addr_to_mult.clear(); + }) + }, + ); +} + +fn compile(c: &mut Criterion) { + let input = poseidon_program(); + + c.bench_with_input( + BenchmarkId::new("compile", format!("{} instructions", input.vec.len())), + &input, + |b, operations| { + let mut compiler = AsmCompiler::>::default(); + b.iter(|| { + compiler.compile(operations.clone()); + }) + }, + ); +} + +criterion_group! { + name = benches; + config = Criterion::default().measurement_time(Duration::from_secs(60)); + targets = compile +} +criterion_main!(benches); diff --git a/crates/recursion/compiler/src/circuit/compiler.rs b/crates/recursion/compiler/src/circuit/compiler.rs index 4ff97e0ec5..e537bb0e35 100644 --- a/crates/recursion/compiler/src/circuit/compiler.rs +++ b/crates/recursion/compiler/src/circuit/compiler.rs @@ -151,14 +151,13 @@ where self.consts.entry(imm).or_insert_with(|| (Self::alloc(&mut self.next_addr), C::F::zero())).0 } - fn mem_write_const(&mut self, dst: impl Reg, src: Imm) -> CompileOneItem { + fn mem_write_const(&mut self, dst: impl Reg, src: Imm) -> Instruction { Instruction::Mem(MemInstr { addrs: MemIo { inner: dst.write(self) }, vals: MemIo { inner: src.as_block() }, mult: C::F::zero(), kind: MemAccessKind::Write, }) - .into() } fn base_alu( @@ -167,13 +166,12 @@ where dst: impl Reg, lhs: impl Reg, rhs: impl Reg, - ) -> CompileOneItem { + ) -> Instruction { Instruction::BaseAlu(BaseAluInstr { opcode, mult: C::F::zero(), addrs: BaseAluIo { out: dst.write(self), in1: lhs.read(self), in2: rhs.read(self) }, }) - .into() } fn ext_alu( @@ -182,20 +180,19 @@ where dst: impl Reg, lhs: impl Reg, rhs: impl Reg, - ) -> CompileOneItem { + ) -> Instruction { Instruction::ExtAlu(ExtAluInstr { opcode, mult: C::F::zero(), addrs: ExtAluIo { out: dst.write(self), in1: lhs.read(self), in2: rhs.read(self) }, }) - .into() } fn base_assert_eq( &mut self, lhs: impl Reg, rhs: impl Reg, - mut f: impl FnMut(CompileOneItem), + mut f: impl FnMut(Instruction), ) { use BaseAluOpcode::*; let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr)); @@ -207,7 +204,7 @@ where &mut self, lhs: impl Reg, rhs: impl Reg, - mut f: impl FnMut(CompileOneItem), + mut f: impl FnMut(Instruction), ) { use BaseAluOpcode::*; let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr)); @@ -220,7 +217,7 @@ where &mut self, lhs: impl Reg, rhs: impl Reg, - mut f: impl FnMut(CompileOneItem), + mut f: impl FnMut(Instruction), ) { use ExtAluOpcode::*; let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr)); @@ -233,7 +230,7 @@ where &mut self, lhs: impl Reg, rhs: impl Reg, - mut f: impl FnMut(CompileOneItem), + mut f: impl FnMut(Instruction), ) { use ExtAluOpcode::*; let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr)); @@ -246,7 +243,7 @@ where &mut self, dst: [impl Reg; WIDTH], src: [impl Reg; WIDTH], - ) -> CompileOneItem { + ) -> Instruction { Instruction::Poseidon2(Box::new(Poseidon2Instr { addrs: Poseidon2Io { input: src.map(|r| r.read(self)), @@ -254,7 +251,6 @@ where }, mults: [C::F::zero(); WIDTH], })) - .into() } fn exp_reverse_bits( @@ -262,7 +258,7 @@ where dst: impl Reg, base: impl Reg, exp: impl IntoIterator>, - ) -> CompileOneItem { + ) -> Instruction { Instruction::ExpReverseBitsLen(ExpReverseBitsInstr { addrs: ExpReverseBitsIo { result: dst.write(self), @@ -271,19 +267,17 @@ where }, mult: C::F::zero(), }) - .into() } fn hint_bit_decomposition( &mut self, value: impl Reg, output: impl IntoIterator>, - ) -> CompileOneItem { + ) -> Instruction { Instruction::HintBits(HintBitsInstr { output_addrs_mults: output.into_iter().map(|r| (r.write(self), C::F::zero())).collect(), input_addr: value.read_ghost(self), }) - .into() } fn fri_fold( @@ -298,7 +292,7 @@ where alpha_pow_input, ro_input, }: CircuitV2FriFoldInput, - ) -> CompileOneItem { + ) -> Instruction { Instruction::FriFold(Box::new(FriFoldInstr { // Calculate before moving the vecs. alpha_pow_mults: vec![C::F::zero(); alpha_pow_output.len()], @@ -315,13 +309,12 @@ where ro_output: ro_output.into_iter().map(|e| e.write(self)).collect(), }, })) - .into() } fn commit_public_values( &mut self, public_values: &RecursionPublicValues>, - ) -> CompileOneItem { + ) -> Instruction { public_values.digest.iter().for_each(|x| { let _ = x.read(self); }); @@ -338,38 +331,33 @@ where Instruction::CommitPublicValues(Box::new(CommitPublicValuesInstr { pv_addrs: *public_values_a, })) - .into() } - fn print_f(&mut self, addr: impl Reg) -> CompileOneItem { + fn print_f(&mut self, addr: impl Reg) -> Instruction { Instruction::Print(PrintInstr { field_elt_type: FieldEltType::Base, addr: addr.read_ghost(self), }) - .into() } - fn print_e(&mut self, addr: impl Reg) -> CompileOneItem { + fn print_e(&mut self, addr: impl Reg) -> Instruction { Instruction::Print(PrintInstr { field_elt_type: FieldEltType::Extension, addr: addr.read_ghost(self), }) - .into() } - fn ext2felts(&mut self, felts: [impl Reg; D], ext: impl Reg) -> CompileOneItem { + fn ext2felts(&mut self, felts: [impl Reg; D], ext: impl Reg) -> Instruction { Instruction::HintExt2Felts(HintExt2FeltsInstr { output_addrs_mults: felts.map(|r| (r.write(self), C::F::zero())), input_addr: ext.read_ghost(self), }) - .into() } - fn hint(&mut self, output: &[impl Reg]) -> CompileOneItem { + fn hint(&mut self, output: &[impl Reg]) -> Instruction { Instruction::Hint(HintInstr { output_addrs_mults: output.iter().map(|r| (r.write(self), C::F::zero())).collect(), }) - .into() } /// Compiles one instruction, passing one or more instructions to `consumer`. @@ -379,7 +367,7 @@ where pub fn compile_one( &mut self, ir_instr: DslIr, - mut consumer: impl FnMut(Result, DslIr>), + mut consumer: impl FnMut(Result, CompileOneErr>), ) where F: PrimeField + TwoAdicField, C: Config + Debug, @@ -476,9 +464,11 @@ where DslIr::CircuitV2HintFelts(output) => f(self.hint(&output)), DslIr::CircuitV2HintExts(output) => f(self.hint(&output)), DslIr::CircuitExt2Felt(felts, ext) => f(self.ext2felts(felts, ext)), - DslIr::CycleTrackerV2Enter(name) => f(CompileOneItem::CycleTrackerEnter(name)), - DslIr::CycleTrackerV2Exit => f(CompileOneItem::CycleTrackerExit), - instr => consumer(Err(instr)), + DslIr::CycleTrackerV2Enter(name) => { + consumer(Err(CompileOneErr::CycleTrackerEnter(name))) + } + DslIr::CycleTrackerV2Exit => consumer(Err(CompileOneErr::CycleTrackerExit)), + instr => consumer(Err(CompileOneErr::Unsupported(instr))), } } @@ -494,25 +484,25 @@ where // Compile each IR instruction into a list of ASM instructions, then combine them. // This step also counts the number of times each address is read from. let (mut instrs, traces) = tracing::debug_span!("compile_one loop").in_scope(|| { - let mut instrs = vec![]; + let mut instrs = Vec::with_capacity(operations.vec.len()); let mut traces = vec![]; if debug_mode { let mut span_builder = SpanBuilder::<_, &'static str>::new("cycle_tracker".to_string()); for (ir_instr, trace) in operations { - self.compile_one(ir_instr, |item| match item { - Ok(CompileOneItem::Instr(instr)) => { + self.compile_one(ir_instr, &mut |item| match item { + Ok(instr) => { span_builder.item(instr_name(&instr)); instrs.push(instr); traces.push(trace.clone()); } - Ok(CompileOneItem::CycleTrackerEnter(name)) => { + Err(CompileOneErr::CycleTrackerEnter(name)) => { span_builder.enter(name); } - Ok(CompileOneItem::CycleTrackerExit) => { + Err(CompileOneErr::CycleTrackerExit) => { span_builder.exit().unwrap(); } - Err(instr) => { + Err(CompileOneErr::Unsupported(instr)) => { panic!("unsupported instruction: {instr:?}\nbacktrace: {:?}", trace) } }); @@ -523,10 +513,12 @@ where } } else { for (ir_instr, trace) in operations { - self.compile_one(ir_instr, |item| match item { - Ok(CompileOneItem::Instr(instr)) => instrs.push(instr), - Ok(_) => (), - Err(instr) => { + self.compile_one(ir_instr, &mut |item| match item { + Ok(instr) => instrs.push(instr), + Err( + CompileOneErr::CycleTrackerEnter(_) | CompileOneErr::CycleTrackerExit, + ) => (), + Err(CompileOneErr::Unsupported(instr)) => { panic!("unsupported instruction: {instr:?}\nbacktrace: {:?}", trace) } }); @@ -649,20 +641,13 @@ const fn instr_name(instr: &Instruction) -> &'static str { } } -/// Instruction or annotation. Result of compiling one `DslIr` item. #[derive(Debug, Clone)] -pub enum CompileOneItem { - Instr(Instruction), +pub enum CompileOneErr { + Unsupported(DslIr), CycleTrackerEnter(String), CycleTrackerExit, } -impl From> for CompileOneItem { - fn from(value: Instruction) -> Self { - CompileOneItem::Instr(value) - } -} - /// Immediate (i.e. constant) field element. /// /// Required to distinguish a base and extension field element at the type level,