Skip to content

Commit

Permalink
Hand-written bus witness generation
Browse files Browse the repository at this point in the history
  • Loading branch information
georgwiese committed Dec 4, 2024
1 parent a7d9d7d commit ce15e57
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 11 deletions.
7 changes: 7 additions & 0 deletions backend/src/mock/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ impl<F: FieldElement> Backend<F> for MockBackend<F> {
})
.collect::<BTreeMap<_, _>>();

let start = std::time::Instant::now();
let machines = self
.machine_to_pil
// TODO: We should probably iterate in parallel, because Machine::try_new might generate
Expand All @@ -121,6 +122,12 @@ impl<F: FieldElement> Backend<F> for MockBackend<F> {
})
.map(|machine| (machine.machine_name.clone(), machine))
.collect::<BTreeMap<_, _>>();
if !challenges.is_empty() {
log::info!(
"Generating later-stage witnesses took {:.2}s",
start.elapsed().as_secs_f32()
);
}

let is_ok = machines.values().all(|machine| {
!PolynomialConstraintChecker::new(machine, &challenges)
Expand Down
215 changes: 215 additions & 0 deletions executor/src/witgen/bus_accumulator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
use std::collections::BTreeMap;

use powdr_ast::analyzed::{Analyzed, Identity};
use powdr_number::FieldElement;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

use crate::witgen::evaluators::expression_evaluator::ExpressionEvaluator;

use super::evaluators::expression_evaluator::OwnedTraceValues;

pub fn generate_bus_accumulators<T: FieldElement>(
pil: &Analyzed<T>,
witness_columns: &[(String, Vec<T>)],
fixed_columns: Vec<(String, &[T])>,
challenges: BTreeMap<u64, T>,
) -> Vec<(String, Vec<T>)> {
let bus_interactions = pil
.identities
.iter()
.filter_map(|identity| match identity {
Identity::PhantomBusInteraction(i) => Some(i),
_ => None,
})
.collect::<Vec<_>>();

let trace_values = OwnedTraceValues::new(
pil,
witness_columns.to_vec(),
fixed_columns
.into_iter()
.map(|(name, values)| (name, values.to_vec()))
.collect(),
);
let accumulators = (0..bus_interactions.len())
.into_par_iter()
.map(|i| interaction_columns(pil, i, &trace_values, &challenges))
.collect::<Vec<_>>();

witness_columns
.iter()
.cloned()
.chain(accumulators.into_iter().flatten())
.collect()
}

fn interaction_columns<T: FieldElement>(
pil: &Analyzed<T>,
connection_index: usize,
trace_values: &OwnedTraceValues<T>,
challenges: &BTreeMap<u64, T>,
) -> Vec<(String, Vec<T>)> {
let bus_interactions = pil
.identities
.iter()
.filter_map(|identity| match identity {
Identity::PhantomBusInteraction(i) => Some(i),
_ => None,
})
.collect::<Vec<_>>();
let bus_interaction = bus_interactions[connection_index];

let namespace = pil
.committed_polys_in_source_order()
.next()
.unwrap()
.0
.absolute_name
.split("::")
.next()
.unwrap();
let intermediate_definitions = pil.intermediate_definitions();

let size = trace_values.height();
let mut acc1 = vec![T::zero(); size];
let mut acc2 = vec![T::zero(); size];
let mut acc1_next = vec![T::zero(); size];
let mut acc2_next = vec![T::zero(); size];

let alpha = (challenges[&1], challenges[&2]);
let beta = (challenges[&3], challenges[&4]);

for i in 0..size {
let mut evaluator =
ExpressionEvaluator::new(trace_values.row(i), &intermediate_definitions, challenges);
let current_acc = if i == 0 {
(T::zero(), T::zero())
} else {
(acc1[i - 1], acc2[i - 1])
};

let tuple = bus_interaction
.tuple
.0
.iter()
.map(|r| evaluator.evaluate(r))
.collect::<Vec<_>>();

let fingerprint = sub_ext(beta, fingerprint(&tuple, alpha));
let multiplicity = evaluator.evaluate(&bus_interaction.multiplicity);

/*
add_ext(
current_acc,
mul_ext(m_ext_next, inv_ext(folded_next))
)
*/
let update = add_ext(
current_acc,
mul_ext((multiplicity, T::from(0)), inv_ext(fingerprint)),
);

acc1[i] = update.0;
acc2[i] = update.1;
acc1_next[(i + size - 1) % size] = update.0;
acc2_next[(i + size - 1) % size] = update.1;
}

vec![
(name(namespace, "acc", connection_index * 2), acc1),
(name(namespace, "acc", connection_index * 2 + 1), acc2),
(name(namespace, "acc_next", connection_index * 2), acc1_next),
(
name(namespace, "acc_next", connection_index * 2 + 1),
acc2_next,
),
]
}

fn name(namespace: &str, base: &str, i: usize) -> String {
if i == 0 {
return format!("{namespace}::{base}");
}
format!("{namespace}::{base}_{i}")
}

/*
let<T: Add + FromLiteral + Mul> mul_ext: Fp2<T>, Fp2<T> -> Fp2<T> = |a, b| match (a, b) {
(Fp2::Fp2(a0, a1), Fp2::Fp2(b0, b1)) => Fp2::Fp2(
// Multiplication modulo the polynomial x^2 - 11. We'll use the fact
// that x^2 == 11 (mod x^2 - 11), so:
// (a0 + a1 * x) * (b0 + b1 * x) = a0 * b0 + 11 * a1 * b1 + (a1 * b0 + a0 * b1) * x (mod x^2 - 11)
a0 * b0 + 11 * a1 * b1,
a1 * b0 + a0 * b1
)
};
*/

fn mul_ext<T: FieldElement>(a: (T, T), b: (T, T)) -> (T, T) {
(a.0 * b.0 + a.1 * b.1 * T::from(11), a.1 * b.0 + a.0 * b.1)
}

fn add_ext<T: FieldElement>(a: (T, T), b: (T, T)) -> (T, T) {
(a.0 + b.0, a.1 + b.1)
}

fn sub_ext<T: FieldElement>(a: (T, T), b: (T, T)) -> (T, T) {
(a.0 - b.0, a.1 - b.1)
}

/*
/// Maps [x_1, x_2, ..., x_n] to its Read-Solomon fingerprint, using a challenge alpha: $\sum_{i=1}^n alpha**{(n - i)} * x_i$
/// To generate an expression that computes the fingerprint, use `fingerprint_inter` instead.
/// Note that alpha is passed as an expressions, so that it is only evaluated if needed (i.e., if len(expr_array) > 1).
let fingerprint: fe[], Fp2<expr> -> Fp2<fe> = query |expr_array, alpha| {
fingerprint_impl(expr_array, eval_ext(alpha), len(expr_array))
};
let fingerprint_impl: fe[], Fp2<fe>, int -> Fp2<fe> = query |expr_array, alpha, l| if l == 1 {
// Base case
from_base(expr_array[0])
} else {
// Recursively compute the fingerprint as fingerprint(expr_array[:-1], alpha) * alpha + expr_array[-1]
let intermediate_fingerprint = fingerprint_impl(expr_array, alpha, l - 1);
add_ext(mul_ext(alpha, intermediate_fingerprint), from_base(expr_array[l - 1]))
};
*/

fn fingerprint<T: FieldElement>(expr_array: &[T], alpha: (T, T)) -> (T, T) {
fingerprint_impl(expr_array, alpha, expr_array.len())
}

fn fingerprint_impl<T: FieldElement>(expr_array: &[T], alpha: (T, T), l: usize) -> (T, T) {
if l == 1 {
return (expr_array[0], T::zero());
}

let intermediate_fingerprint = fingerprint_impl(expr_array, alpha, l - 1);
add_ext(
mul_ext(alpha, intermediate_fingerprint),
(expr_array[l - 1], T::zero()),
)
}

/*
/// Extension field inversion
let inv_ext: Fp2<fe> -> Fp2<fe> = |a| match a {
// The inverse of (a0, a1) is a point (b0, b1) such that:
// (a0 + a1 * x) (b0 + b1 * x) = 1 (mod x^2 - 11)
// Multiplying out and plugging in x^2 = 11 yields the following system of linear equations:
// a0 * b0 + 11 * a1 * b1 = 1
// a1 * b0 + a0 * b1 = 0
// Solving for (b0, b1) yields:
Fp2::Fp2(a0, a1) => {
let factor = inv_field(11 * a1 * a1 - a0 * a0);
Fp2::Fp2(-a0 * factor, a1 * factor)
}
};
*/
fn inv_ext<T: FieldElement>(a: (T, T)) -> (T, T) {
let factor = T::from(1) / (T::from(11) * a.1 * a.1 - a.0 * a.0);
(-a.0 * factor, a.1 * factor)
}
11 changes: 7 additions & 4 deletions executor/src/witgen/evaluators/expression_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,19 @@ impl<T> OwnedTraceValues<T> {
.committed_polys_in_source_order()
.chain(pil.constant_polys_in_source_order())
.flat_map(|(symbol, _)| symbol.array_elements())
.map(|(name, poly_id)| {
let column = columns_by_name
.filter_map(|(name, poly_id)| {
columns_by_name
.remove(&name)
.unwrap_or_else(|| panic!("Missing column: {name}"));
(poly_id, column)
.map(|column| (poly_id, column))
})
.collect();
Self { values }
}

pub fn height(&self) -> usize {
self.values.values().next().map(|v| v.len()).unwrap()
}

pub fn row(&self, row: usize) -> RowTraceValues<T> {
RowTraceValues { trace: self, row }
}
Expand Down
49 changes: 42 additions & 7 deletions executor/src/witgen/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::sync::Arc;

use bus_accumulator::generate_bus_accumulators;
use itertools::Itertools;
use machines::machine_extractor::MachineExtractor;
use powdr_ast::analyzed::{
AlgebraicExpression, AlgebraicReference, AlgebraicReferenceThin, Analyzed, DegreeRange,
Expression, FunctionValueDefinition, PolyID, PolynomialType, Symbol, SymbolKind,
Expression, FunctionValueDefinition, Identity, PolyID, PolynomialType, Symbol, SymbolKind,
TypedExpression,
};
use powdr_ast::parsed::visitor::{AllChildren, ExpressionVisitable};
use powdr_ast::parsed::{FunctionKind, LambdaExpression};
use powdr_number::{DegreeType, FieldElement};
use powdr_number::{DegreeType, FieldElement, KnownField};
use std::iter::once;

use crate::constant_evaluator::VariablySizedColumn;
Expand All @@ -27,6 +28,7 @@ use self::machines::profiling::{record_end, record_start, reset_and_print_profil
mod affine_expression;
pub(crate) mod analysis;
mod block_processor;
mod bus_accumulator;
mod data_structures;
mod eval_result;
pub mod evaluators;
Expand Down Expand Up @@ -94,6 +96,25 @@ impl<T: FieldElement> WitgenCallbackContext<T> {
.collect()
}

pub fn select_fixed_columns2(
&self,
pil: &Analyzed<T>,
size: DegreeType,
) -> Vec<(String, &[T])> {
// The provided PIL might only contain a subset of all fixed columns.
let fixed_column_names = pil
.constant_polys_in_source_order()
.flat_map(|(symbol, _)| symbol.array_elements())
.map(|(name, _)| name.clone())
.collect::<BTreeSet<_>>();
// Select the columns in the current PIL and select the right size.
self.fixed_col_values
.iter()
.filter(|(n, _)| fixed_column_names.contains(n))
.map(|(n, v)| (n.clone(), v.get_by_size(size).unwrap()))
.collect()
}

/// Computes the next-stage witness, given the current witness and challenges.
/// All columns in the provided PIL are expected to have the same size.
/// Typically, this function should be called once per machine.
Expand All @@ -105,11 +126,25 @@ impl<T: FieldElement> WitgenCallbackContext<T> {
stage: u8,
) -> Vec<(String, Vec<T>)> {
let size = current_witness.iter().next().unwrap().1.len() as DegreeType;
let fixed_col_values = self.select_fixed_columns(pil, size);
WitnessGenerator::new(pil, &fixed_col_values, &*self.query_callback)
.with_external_witness_values(current_witness)
.with_challenges(stage, challenges)
.generate()

let has_phantom_bus_sends = pil
.identities
.iter()
.any(|identity| matches!(identity, Identity::PhantomBusInteraction(_)));

if has_phantom_bus_sends && T::known_field() == Some(KnownField::GoldilocksField) {
log::debug!("Using hand-written bus witgen.");
let fixed_col_values = self.select_fixed_columns2(pil, size);
assert_eq!(stage, 1);
generate_bus_accumulators(pil, current_witness, fixed_col_values, challenges)
} else {
log::debug!("Using automatic stage-1 witgen.");
let fixed_col_values = self.select_fixed_columns(pil, size);
WitnessGenerator::new(pil, &fixed_col_values, &*self.query_callback)
.with_external_witness_values(current_witness)
.with_challenges(stage, challenges)
.generate()
}
}
}

Expand Down

0 comments on commit ce15e57

Please sign in to comment.