Skip to content

Commit

Permalink
Prototype hand-written witgen
Browse files Browse the repository at this point in the history
  • Loading branch information
georgwiese committed Dec 3, 2024
1 parent b2c0812 commit 0f9d13e
Show file tree
Hide file tree
Showing 2 changed files with 211 additions and 5 deletions.
182 changes: 182 additions & 0 deletions executor/src/witgen/bus_accumulator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
use std::collections::BTreeMap;

use powdr_ast::analyzed::Analyzed;
use powdr_number::FieldElement;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

pub fn generate_bus_accumulators<T: FieldElement>(
_pil: &Analyzed<T>,
witness_columns: &[(String, Vec<T>)],
_fixed_columns: Vec<(String, &[T])>,
_challenges: BTreeMap<u64, T>,
) -> Vec<(String, Vec<T>)> {
let accumulators = (0..31)
.into_par_iter()
.map(|i| interaction_columns(i, witness_columns))
.collect::<Vec<_>>();

witness_columns
.to_vec()
.into_iter()
.chain(accumulators.into_iter().flatten())
.collect()
}

fn interaction_columns<T: FieldElement>(
connection_index: usize,
witness_columns: &[(String, Vec<T>)],
) -> Vec<(String, Vec<T>)> {
let tuple_size = if connection_index == 0 {
// Simulate PC lookup
700
} else {
1 + connection_index / 10
};

// Pick random indices
let indices = (0..tuple_size)
.map(|j| {
((42usize
.wrapping_mul(connection_index * 70 + j)
.wrapping_add(123))
% 17482394)
% witness_columns.len()
})
.collect::<Vec<_>>();

let size = witness_columns[0].1.len();
let mut acc1 = vec![T::zero(); size];
let mut acc2 = vec![T::zero(); size];
let mut acc1_next = vec![T::zero(); size];
let mut acc2_next = vec![T::zero(); size];

for i in 0..size {
let current_acc = if i == 0 {
(T::zero(), T::zero())
} else {
(acc1[i - 1], acc2[i - 1])
};

let tuple = indices
.iter()
.map(|&j| witness_columns[j].1[i])
.collect::<Vec<_>>();

let fingerprint = add_ext(
fingerprint(&tuple, (T::from(1234), T::from(12345))),
(T::from(8764), T::from(876324)),
);
let multiplicity = T::from(42);

/*
add_ext(
current_acc,
mul_ext(m_ext_next, inv_ext(folded_next))
)
*/
let update = add_ext(
current_acc,
mul_ext((multiplicity, T::from(0)), inv_ext(fingerprint)),
);

acc1[i] = update.0;
acc2[i] = update.1;
acc1_next[(i + size - 1) % size] = update.0;
acc2_next[(i + size - 1) % size] = update.1;
}

vec![
(name("main::acc", connection_index * 2), acc1),
(name("main::acc", connection_index * 2 + 1), acc2),
(name("main::acc_next", connection_index * 2), acc1_next),
(name("main::acc_next", connection_index * 2 + 1), acc2_next),
]
}

fn name(base: &str, i: usize) -> String {
if i == 0 {
return base.to_string();
}
format!("{base}_{}", i)
}

/*
let<T: Add + FromLiteral + Mul> mul_ext: Fp2<T>, Fp2<T> -> Fp2<T> = |a, b| match (a, b) {
(Fp2::Fp2(a0, a1), Fp2::Fp2(b0, b1)) => Fp2::Fp2(
// Multiplication modulo the polynomial x^2 - 11. We'll use the fact
// that x^2 == 11 (mod x^2 - 11), so:
// (a0 + a1 * x) * (b0 + b1 * x) = a0 * b0 + 11 * a1 * b1 + (a1 * b0 + a0 * b1) * x (mod x^2 - 11)
a0 * b0 + 11 * a1 * b1,
a1 * b0 + a0 * b1
)
};
*/

fn mul_ext<T: FieldElement>(a: (T, T), b: (T, T)) -> (T, T) {
(a.0 * b.0 + a.1 * b.1 * T::from(11), a.1 * b.0 + a.0 * b.1)
}

fn add_ext<T: FieldElement>(a: (T, T), b: (T, T)) -> (T, T) {
(a.0 + b.0, a.1 + b.1)
}

fn sub_ext<T: FieldElement>(a: (T, T), b: (T, T)) -> (T, T) {
(a.0 - b.0, a.1 - b.1)
}

/*
/// Maps [x_1, x_2, ..., x_n] to its Read-Solomon fingerprint, using a challenge alpha: $\sum_{i=1}^n alpha**{(n - i)} * x_i$
/// To generate an expression that computes the fingerprint, use `fingerprint_inter` instead.
/// Note that alpha is passed as an expressions, so that it is only evaluated if needed (i.e., if len(expr_array) > 1).
let fingerprint: fe[], Fp2<expr> -> Fp2<fe> = query |expr_array, alpha| {
fingerprint_impl(expr_array, eval_ext(alpha), len(expr_array))
};
let fingerprint_impl: fe[], Fp2<fe>, int -> Fp2<fe> = query |expr_array, alpha, l| if l == 1 {
// Base case
from_base(expr_array[0])
} else {
// Recursively compute the fingerprint as fingerprint(expr_array[:-1], alpha) * alpha + expr_array[-1]
let intermediate_fingerprint = fingerprint_impl(expr_array, alpha, l - 1);
add_ext(mul_ext(alpha, intermediate_fingerprint), from_base(expr_array[l - 1]))
};
*/

fn fingerprint<T: FieldElement>(expr_array: &[T], alpha: (T, T)) -> (T, T) {
fingerprint_impl(expr_array, alpha, expr_array.len())
}

fn fingerprint_impl<T: FieldElement>(expr_array: &[T], alpha: (T, T), l: usize) -> (T, T) {
if l == 1 {
return (expr_array[0], T::zero());
}

let intermediate_fingerprint = fingerprint_impl(expr_array, alpha, l - 1);
add_ext(
mul_ext(alpha, intermediate_fingerprint),
(expr_array[l - 1], T::zero()),
)
}

/*
/// Extension field inversion
let inv_ext: Fp2<fe> -> Fp2<fe> = |a| match a {
// The inverse of (a0, a1) is a point (b0, b1) such that:
// (a0 + a1 * x) (b0 + b1 * x) = 1 (mod x^2 - 11)
// Multiplying out and plugging in x^2 = 11 yields the following system of linear equations:
// a0 * b0 + 11 * a1 * b1 = 1
// a1 * b0 + a0 * b1 = 0
// Solving for (b0, b1) yields:
Fp2::Fp2(a0, a1) => {
let factor = inv_field(11 * a1 * a1 - a0 * a0);
Fp2::Fp2(-a0 * factor, a1 * factor)
}
};
*/
fn inv_ext<T: FieldElement>(a: (T, T)) -> (T, T) {
let factor = T::from(1) / (T::from(11) * a.1 * a.1 - a.0 * a.0);
(-a.0 * factor, a.1 * factor)
}
34 changes: 29 additions & 5 deletions executor/src/witgen/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::sync::Arc;

use bus_accumulator::generate_bus_accumulators;
use itertools::Itertools;
use machines::machine_extractor::MachineExtractor;
use powdr_ast::analyzed::{
Expand All @@ -27,6 +28,7 @@ use self::machines::profiling::{record_end, record_start, reset_and_print_profil
mod affine_expression;
pub(crate) mod analysis;
mod block_processor;
mod bus_accumulator;
mod data_structures;
mod eval_result;
mod expression_evaluator;
Expand Down Expand Up @@ -97,6 +99,25 @@ impl<T: FieldElement> WitgenCallbackContext<T> {
.collect()
}

pub fn select_fixed_columns2(
&self,
pil: &Analyzed<T>,
size: DegreeType,
) -> Vec<(String, &[T])> {
// The provided PIL might only contain a subset of all fixed columns.
let fixed_column_names = pil
.constant_polys_in_source_order()
.flat_map(|(symbol, _)| symbol.array_elements())
.map(|(name, _)| name.clone())
.collect::<BTreeSet<_>>();
// Select the columns in the current PIL and select the right size.
self.fixed_col_values
.iter()
.filter(|(n, _)| fixed_column_names.contains(n))
.map(|(n, v)| (n.clone(), v.get_by_size(size).unwrap()))
.collect()
}

/// Computes the next-stage witness, given the current witness and challenges.
/// All columns in the provided PIL are expected to have the same size.
/// Typically, this function should be called once per machine.
Expand All @@ -108,11 +129,14 @@ impl<T: FieldElement> WitgenCallbackContext<T> {
stage: u8,
) -> Vec<(String, Vec<T>)> {
let size = current_witness.iter().next().unwrap().1.len() as DegreeType;
let fixed_col_values = self.select_fixed_columns(pil, size);
WitnessGenerator::new(pil, &fixed_col_values, &*self.query_callback)
.with_external_witness_values(current_witness)
.with_challenges(stage, challenges)
.generate()
let fixed_col_values = self.select_fixed_columns2(pil, size);

// WitnessGenerator::new(pil, &fixed_col_values, &*self.query_callback)
// .with_external_witness_values(current_witness)
// .with_challenges(stage, challenges)
// .generate()
assert_eq!(stage, 1);
generate_bus_accumulators(pil, current_witness, fixed_col_values, challenges)
}
}

Expand Down

0 comments on commit 0f9d13e

Please sign in to comment.