Skip to content

Commit

Permalink
perf: move cycle tracker to Err in compile_one, criterion benchm…
Browse files Browse the repository at this point in the history
…arks (#1369)
  • Loading branch information
tqn authored Aug 22, 2024
1 parent d3a73a9 commit 8d5c103
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 51 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions crates/recursion/compiler/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,8 @@ p3-challenger = { workspace = true }
p3-dft = { workspace = true }
p3-merkle-tree = { workspace = true }
rand = "0.8.5"
criterion = { version = "0.5.1", features = ["html_reports"] }

[[bench]]
name = "circuit"
harness = false
85 changes: 85 additions & 0 deletions crates/recursion/compiler/benches/circuit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
use std::time::Duration;

use criterion::*;
use p3_symmetric::Permutation;
use rand::{rngs::StdRng, Rng, SeedableRng};

use sp1_recursion_compiler::{
asm::{AsmBuilder, AsmConfig},
circuit::*,
ir::{DslIr, TracedVec},
prelude::Felt,
};
use sp1_recursion_core_v2::chips::poseidon2_wide::WIDTH;
use sp1_stark::{baby_bear_poseidon2::BabyBearPoseidon2, inner_perm, StarkGenericConfig};

type SC = BabyBearPoseidon2;
type F = <SC as StarkGenericConfig>::Val;
type EF = <SC as StarkGenericConfig>::Challenge;
type C = AsmConfig<F, EF>;

fn poseidon_program() -> TracedVec<DslIr<C>> {
let mut builder = AsmBuilder::<F, EF>::default();
let mut rng = StdRng::seed_from_u64(0xCAFEDA7E)
.sample_iter::<[F; WIDTH], _>(rand::distributions::Standard);
for _ in 0..100 {
let input_1: [F; WIDTH] = rng.next().unwrap();
let output_1 = inner_perm().permute(input_1);

let input_1_felts = input_1.map(|x| builder.eval(x));
let output_1_felts = builder.poseidon2_permute_v2(input_1_felts);
let expected: [Felt<_>; WIDTH] = output_1.map(|x| builder.eval(x));
for (lhs, rhs) in output_1_felts.into_iter().zip(expected) {
builder.assert_felt_eq(lhs, rhs);
}
}
builder.operations
}

#[allow(dead_code)]
fn compile_one(c: &mut Criterion) {
let input = {
let mut ops = poseidon_program().vec;
ops.truncate(100);
ops
};

c.bench_with_input(
BenchmarkId::new("compile_one", format!("{} instructions", input.len())),
&input,
|b, operations| {
let mut compiler = AsmCompiler::<AsmConfig<F, EF>>::default();
b.iter(|| {
for instr in operations.iter().cloned() {
compiler.compile_one(std::hint::black_box(instr), drop);
}
compiler.next_addr = Default::default();
compiler.virtual_to_physical.clear();
compiler.consts.clear();
compiler.addr_to_mult.clear();
})
},
);
}

fn compile(c: &mut Criterion) {
let input = poseidon_program();

c.bench_with_input(
BenchmarkId::new("compile", format!("{} instructions", input.vec.len())),
&input,
|b, operations| {
let mut compiler = AsmCompiler::<AsmConfig<F, EF>>::default();
b.iter(|| {
compiler.compile(operations.clone());
})
},
);
}

criterion_group! {
name = benches;
config = Criterion::default().measurement_time(Duration::from_secs(60));
targets = compile
}
criterion_main!(benches);
87 changes: 36 additions & 51 deletions crates/recursion/compiler/src/circuit/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,13 @@ where
self.consts.entry(imm).or_insert_with(|| (Self::alloc(&mut self.next_addr), C::F::zero())).0
}

fn mem_write_const(&mut self, dst: impl Reg<C>, src: Imm<C::F, C::EF>) -> CompileOneItem<C::F> {
fn mem_write_const(&mut self, dst: impl Reg<C>, src: Imm<C::F, C::EF>) -> Instruction<C::F> {
Instruction::Mem(MemInstr {
addrs: MemIo { inner: dst.write(self) },
vals: MemIo { inner: src.as_block() },
mult: C::F::zero(),
kind: MemAccessKind::Write,
})
.into()
}

fn base_alu(
Expand All @@ -167,13 +166,12 @@ where
dst: impl Reg<C>,
lhs: impl Reg<C>,
rhs: impl Reg<C>,
) -> CompileOneItem<C::F> {
) -> Instruction<C::F> {
Instruction::BaseAlu(BaseAluInstr {
opcode,
mult: C::F::zero(),
addrs: BaseAluIo { out: dst.write(self), in1: lhs.read(self), in2: rhs.read(self) },
})
.into()
}

fn ext_alu(
Expand All @@ -182,20 +180,19 @@ where
dst: impl Reg<C>,
lhs: impl Reg<C>,
rhs: impl Reg<C>,
) -> CompileOneItem<C::F> {
) -> Instruction<C::F> {
Instruction::ExtAlu(ExtAluInstr {
opcode,
mult: C::F::zero(),
addrs: ExtAluIo { out: dst.write(self), in1: lhs.read(self), in2: rhs.read(self) },
})
.into()
}

fn base_assert_eq(
&mut self,
lhs: impl Reg<C>,
rhs: impl Reg<C>,
mut f: impl FnMut(CompileOneItem<C::F>),
mut f: impl FnMut(Instruction<C::F>),
) {
use BaseAluOpcode::*;
let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr));
Expand All @@ -207,7 +204,7 @@ where
&mut self,
lhs: impl Reg<C>,
rhs: impl Reg<C>,
mut f: impl FnMut(CompileOneItem<C::F>),
mut f: impl FnMut(Instruction<C::F>),
) {
use BaseAluOpcode::*;
let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr));
Expand All @@ -220,7 +217,7 @@ where
&mut self,
lhs: impl Reg<C>,
rhs: impl Reg<C>,
mut f: impl FnMut(CompileOneItem<C::F>),
mut f: impl FnMut(Instruction<C::F>),
) {
use ExtAluOpcode::*;
let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr));
Expand All @@ -233,7 +230,7 @@ where
&mut self,
lhs: impl Reg<C>,
rhs: impl Reg<C>,
mut f: impl FnMut(CompileOneItem<C::F>),
mut f: impl FnMut(Instruction<C::F>),
) {
use ExtAluOpcode::*;
let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr));
Expand All @@ -246,23 +243,22 @@ where
&mut self,
dst: [impl Reg<C>; WIDTH],
src: [impl Reg<C>; WIDTH],
) -> CompileOneItem<C::F> {
) -> Instruction<C::F> {
Instruction::Poseidon2(Box::new(Poseidon2Instr {
addrs: Poseidon2Io {
input: src.map(|r| r.read(self)),
output: dst.map(|r| r.write(self)),
},
mults: [C::F::zero(); WIDTH],
}))
.into()
}

fn exp_reverse_bits(
&mut self,
dst: impl Reg<C>,
base: impl Reg<C>,
exp: impl IntoIterator<Item = impl Reg<C>>,
) -> CompileOneItem<C::F> {
) -> Instruction<C::F> {
Instruction::ExpReverseBitsLen(ExpReverseBitsInstr {
addrs: ExpReverseBitsIo {
result: dst.write(self),
Expand All @@ -271,19 +267,17 @@ where
},
mult: C::F::zero(),
})
.into()
}

fn hint_bit_decomposition(
&mut self,
value: impl Reg<C>,
output: impl IntoIterator<Item = impl Reg<C>>,
) -> CompileOneItem<C::F> {
) -> Instruction<C::F> {
Instruction::HintBits(HintBitsInstr {
output_addrs_mults: output.into_iter().map(|r| (r.write(self), C::F::zero())).collect(),
input_addr: value.read_ghost(self),
})
.into()
}

fn fri_fold(
Expand All @@ -298,7 +292,7 @@ where
alpha_pow_input,
ro_input,
}: CircuitV2FriFoldInput<C>,
) -> CompileOneItem<C::F> {
) -> Instruction<C::F> {
Instruction::FriFold(Box::new(FriFoldInstr {
// Calculate before moving the vecs.
alpha_pow_mults: vec![C::F::zero(); alpha_pow_output.len()],
Expand All @@ -315,13 +309,12 @@ where
ro_output: ro_output.into_iter().map(|e| e.write(self)).collect(),
},
}))
.into()
}

fn commit_public_values(
&mut self,
public_values: &RecursionPublicValues<Felt<C::F>>,
) -> CompileOneItem<C::F> {
) -> Instruction<C::F> {
public_values.digest.iter().for_each(|x| {
let _ = x.read(self);
});
Expand All @@ -338,38 +331,33 @@ where
Instruction::CommitPublicValues(Box::new(CommitPublicValuesInstr {
pv_addrs: *public_values_a,
}))
.into()
}

fn print_f(&mut self, addr: impl Reg<C>) -> CompileOneItem<C::F> {
fn print_f(&mut self, addr: impl Reg<C>) -> Instruction<C::F> {
Instruction::Print(PrintInstr {
field_elt_type: FieldEltType::Base,
addr: addr.read_ghost(self),
})
.into()
}

fn print_e(&mut self, addr: impl Reg<C>) -> CompileOneItem<C::F> {
fn print_e(&mut self, addr: impl Reg<C>) -> Instruction<C::F> {
Instruction::Print(PrintInstr {
field_elt_type: FieldEltType::Extension,
addr: addr.read_ghost(self),
})
.into()
}

fn ext2felts(&mut self, felts: [impl Reg<C>; D], ext: impl Reg<C>) -> CompileOneItem<C::F> {
fn ext2felts(&mut self, felts: [impl Reg<C>; D], ext: impl Reg<C>) -> Instruction<C::F> {
Instruction::HintExt2Felts(HintExt2FeltsInstr {
output_addrs_mults: felts.map(|r| (r.write(self), C::F::zero())),
input_addr: ext.read_ghost(self),
})
.into()
}

fn hint(&mut self, output: &[impl Reg<C>]) -> CompileOneItem<C::F> {
fn hint(&mut self, output: &[impl Reg<C>]) -> Instruction<C::F> {
Instruction::Hint(HintInstr {
output_addrs_mults: output.iter().map(|r| (r.write(self), C::F::zero())).collect(),
})
.into()
}

/// Compiles one instruction, passing one or more instructions to `consumer`.
Expand All @@ -379,7 +367,7 @@ where
pub fn compile_one<F>(
&mut self,
ir_instr: DslIr<C>,
mut consumer: impl FnMut(Result<CompileOneItem<C::F>, DslIr<C>>),
mut consumer: impl FnMut(Result<Instruction<C::F>, CompileOneErr<C>>),
) where
F: PrimeField + TwoAdicField,
C: Config<N = F, F = F> + Debug,
Expand Down Expand Up @@ -476,9 +464,11 @@ where
DslIr::CircuitV2HintFelts(output) => f(self.hint(&output)),
DslIr::CircuitV2HintExts(output) => f(self.hint(&output)),
DslIr::CircuitExt2Felt(felts, ext) => f(self.ext2felts(felts, ext)),
DslIr::CycleTrackerV2Enter(name) => f(CompileOneItem::CycleTrackerEnter(name)),
DslIr::CycleTrackerV2Exit => f(CompileOneItem::CycleTrackerExit),
instr => consumer(Err(instr)),
DslIr::CycleTrackerV2Enter(name) => {
consumer(Err(CompileOneErr::CycleTrackerEnter(name)))
}
DslIr::CycleTrackerV2Exit => consumer(Err(CompileOneErr::CycleTrackerExit)),
instr => consumer(Err(CompileOneErr::Unsupported(instr))),
}
}

Expand All @@ -494,25 +484,25 @@ where
// Compile each IR instruction into a list of ASM instructions, then combine them.
// This step also counts the number of times each address is read from.
let (mut instrs, traces) = tracing::debug_span!("compile_one loop").in_scope(|| {
let mut instrs = vec![];
let mut instrs = Vec::with_capacity(operations.vec.len());
let mut traces = vec![];
if debug_mode {
let mut span_builder =
SpanBuilder::<_, &'static str>::new("cycle_tracker".to_string());
for (ir_instr, trace) in operations {
self.compile_one(ir_instr, |item| match item {
Ok(CompileOneItem::Instr(instr)) => {
self.compile_one(ir_instr, &mut |item| match item {
Ok(instr) => {
span_builder.item(instr_name(&instr));
instrs.push(instr);
traces.push(trace.clone());
}
Ok(CompileOneItem::CycleTrackerEnter(name)) => {
Err(CompileOneErr::CycleTrackerEnter(name)) => {
span_builder.enter(name);
}
Ok(CompileOneItem::CycleTrackerExit) => {
Err(CompileOneErr::CycleTrackerExit) => {
span_builder.exit().unwrap();
}
Err(instr) => {
Err(CompileOneErr::Unsupported(instr)) => {
panic!("unsupported instruction: {instr:?}\nbacktrace: {:?}", trace)
}
});
Expand All @@ -523,10 +513,12 @@ where
}
} else {
for (ir_instr, trace) in operations {
self.compile_one(ir_instr, |item| match item {
Ok(CompileOneItem::Instr(instr)) => instrs.push(instr),
Ok(_) => (),
Err(instr) => {
self.compile_one(ir_instr, &mut |item| match item {
Ok(instr) => instrs.push(instr),
Err(
CompileOneErr::CycleTrackerEnter(_) | CompileOneErr::CycleTrackerExit,
) => (),
Err(CompileOneErr::Unsupported(instr)) => {
panic!("unsupported instruction: {instr:?}\nbacktrace: {:?}", trace)
}
});
Expand Down Expand Up @@ -649,20 +641,13 @@ const fn instr_name<F>(instr: &Instruction<F>) -> &'static str {
}
}

/// Instruction or annotation. Result of compiling one `DslIr` item.
#[derive(Debug, Clone)]
pub enum CompileOneItem<F> {
Instr(Instruction<F>),
pub enum CompileOneErr<C: Config> {
Unsupported(DslIr<C>),
CycleTrackerEnter(String),
CycleTrackerExit,
}

impl<F> From<Instruction<F>> for CompileOneItem<F> {
fn from(value: Instruction<F>) -> Self {
CompileOneItem::Instr(value)
}
}

/// Immediate (i.e. constant) field element.
///
/// Required to distinguish a base and extension field element at the type level,
Expand Down

0 comments on commit 8d5c103

Please sign in to comment.