diff --git a/halo2_proofs/src/dev.rs b/halo2_proofs/src/dev.rs index ce8327974a..8439310b7d 100644 --- a/halo2_proofs/src/dev.rs +++ b/halo2_proofs/src/dev.rs @@ -7,6 +7,8 @@ use std::iter; use std::ops::{Add, Mul, Neg, Range}; use ff::Field; +use rayon::iter::IntoParallelRefIterator; +use rayon::iter::ParallelIterator; use crate::plonk::Assigned; use crate::{ @@ -670,6 +672,48 @@ impl Assignment for MockProver { } } +fn parallelizable_iter_flat_map( + row_indexes: &[usize], + f: Fun, +) -> Vec +where + Fun: Fn(usize) -> Vec, +{ + if ISPAR { + row_indexes.par_iter().flat_map(|row| f(*row)).collect() + } else { + row_indexes.iter().flat_map(|row| f(*row)).collect() + } +} + +fn parallelizable_iter_map( + row_indexes: &[usize], + f: Fun, +) -> Vec +where + Fun: Fn(usize) -> U, +{ + if ISPAR { + row_indexes.par_iter().map(|row| f(*row)).collect() + } else { + row_indexes.iter().map(|row| f(*row)).collect() + } +} + +fn parallelizable_iter_filter_map( + row_indexes: &[usize], + f: Fun, +) -> Vec +where + Fun: Fn(usize) -> Option, +{ + if ISPAR { + row_indexes.par_iter().filter_map(|row| f(*row)).collect() + } else { + row_indexes.iter().filter_map(|row| f(*row)).collect() + } +} + impl MockProver { /// Runs a synthetic keygen-and-prove operation on the given circuit, collecting data /// about the constraints and their assignments. @@ -759,6 +803,12 @@ impl MockProver { self.verify_at_rows(self.usable_rows.clone(), self.usable_rows.clone()) } + /// Returns `Ok(())` if this `MockProver` is satisfied, or a list of errors indicating + /// the reasons that the circuit is not satisfied. Run verify() in parallel. + pub fn verify_par(&self) -> Result<(), Vec> { + self.verify_at_rows_par(self.usable_rows.clone(), self.usable_rows.clone()) + } + /// Returns `Ok(())` if this `MockProver` is satisfied, or a list of errors indicating /// the reasons that the circuit is not satisfied. /// Constraints are only checked at `gate_row_ids`, @@ -767,6 +817,27 @@ impl MockProver { &self, gate_row_ids: I, lookup_input_row_ids: I, + ) -> Result<(), Vec> { + self.verify_at_rows_internal::<_, false>(gate_row_ids, lookup_input_row_ids) + } + + /// Returns `Ok(())` if this `MockProver` is satisfied, or a list of errors indicating + /// the reasons that the circuit is not satisfied. + /// Constraints are only checked at `gate_row_ids`, + /// and lookup inputs are only checked at `lookup_input_row_ids` + /// Run verify_at_rows() in parallel. + pub fn verify_at_rows_par>( + &self, + gate_row_ids: I, + lookup_input_row_ids: I, + ) -> Result<(), Vec> { + self.verify_at_rows_internal::<_, true>(gate_row_ids, lookup_input_row_ids) + } + + fn verify_at_rows_internal, const ISPAR: bool>( + &self, + gate_row_ids: I, + lookup_input_row_ids: I, ) -> Result<(), Vec> { let n = self.n as i32; @@ -800,75 +871,77 @@ impl MockProver { .enumerate() .filter(move |(_, g)| g.queried_selectors().contains(selector)) .flat_map(move |(gate_index, gate)| { - at.iter().flat_map(move |selector_row| { - // Selectors are queried with no rotation. - let gate_row = *selector_row as i32; - - gate.queried_cells().iter().filter_map(move |cell| { - // Determine where this cell should have been assigned. - let cell_row = ((gate_row + n + cell.rotation.0) % n) as usize; - - // Check that it was assigned! - if r.is_assigned(cell.column, cell_row) { - None - } else { - Some(VerifyFailure::CellNotAssigned { - gate: (gate_index, gate.name()).into(), - region: (r_i, r.name.clone()).into(), - column: cell.column, - offset: cell_row as isize - r.rows.unwrap().0 as isize, - }) - } - }) + parallelizable_iter_flat_map::<_, _, ISPAR>(at, |gate_row| { + let ret: Vec = gate + .queried_cells() + .iter() + .filter_map(move |cell| { + // Determine where this cell should have been assigned. + let cell_row = + (((gate_row as i32) + n + cell.rotation.0) % n) as usize; + + // Check that it was assigned! + if r.is_assigned(cell.column, cell_row) { + None + } else { + Some(VerifyFailure::CellNotAssigned { + gate: (gate_index, gate.name()).into(), + region: (r_i, r.name.clone()).into(), + column: cell.column, + offset: cell_row as isize - r.rows.unwrap().0 as isize, + }) + } + }) + .collect(); + ret }) }) }) }); // Check that all gates are satisfied for all rows. - let gate_errors = - self.cs - .gates - .iter() - .enumerate() - .flat_map(|(gate_index, gate)| { - let blinding_rows = - (self.n as usize - (self.cs.blinding_factors() + 1))..(self.n as usize); - (gate_row_ids - .clone() - .into_iter() - .chain(blinding_rows.into_iter())) - .flat_map(move |row| { - fn load_instance<'a, F: FieldExt, T: ColumnType>( - n: i32, - row: i32, - queries: &'a [(Column, Rotation)], - cells: &'a [Vec], - ) -> impl Fn(usize, usize, Rotation) -> Value + 'a - { - move |index, _, _| { - let (column, at) = &queries[index]; - let resolved_row = (row + n + at.0) % n; - Value::Real(cells[column.index()][resolved_row as usize]) - } - } + let blinding_rows = (self.n as usize - (self.cs.blinding_factors() + 1))..(self.n as usize); + let indexes: Vec = + (gate_row_ids.into_iter().chain(blinding_rows.into_iter())).collect(); + let gate_errors = self + .cs + .gates + .iter() + .enumerate() + .flat_map(|(gate_index, gate)| { + fn load_instance<'a, F: FieldExt, T: ColumnType>( + n: i32, + row: i32, + queries: &'a [(Column, Rotation)], + cells: &'a [Vec], + ) -> impl Fn(usize, usize, Rotation) -> Value + 'a { + move |index, _, _| { + let (column, at) = &queries[index]; + let resolved_row = (row + n + at.0) % n; + Value::Real(cells[column.index()][resolved_row as usize]) + } + } - fn load<'a, F: FieldExt, T: ColumnType>( - n: i32, - row: i32, - queries: &'a [(Column, Rotation)], - cells: &'a [Vec>], - ) -> impl Fn(usize, usize, Rotation) -> Value + 'a - { - move |index, _, _| { - let (column, at) = &queries[index]; - let resolved_row = (row + n + at.0) % n; - cells[column.index()][resolved_row as usize].into() - } - } - let row = row as i32; - gate.polynomials().iter().enumerate().filter_map( - move |(poly_index, poly)| match poly.evaluate_lazy( + fn load<'a, F: FieldExt, T: ColumnType>( + n: i32, + row: i32, + queries: &'a [(Column, Rotation)], + cells: &'a [Vec>], + ) -> impl Fn(usize, usize, Rotation) -> Value + 'a { + move |index, _, _| { + let (column, at) = &queries[index]; + let resolved_row = (row + n + at.0) % n; + cells[column.index()][resolved_row as usize].into() + } + } + parallelizable_iter_flat_map::<_, _, ISPAR>(&indexes, |row| { + let row = row as i32; + let ret: Vec = gate + .polynomials() + .iter() + .enumerate() + .filter_map(move |(poly_index, poly)| { + match poly.evaluate_lazy( &|scalar| Value::Real(scalar), &|_| panic!("virtual selectors are removed during optimization"), &load(n, row, &self.cs.fixed_queries, &self.fixed), @@ -915,10 +988,12 @@ impl MockProver { ) .into(), }), - }, - ) - }) - }); + } + }) + .collect(); + ret + }) + }); // Check that all lookups exist in their respective tables. let lookup_errors = @@ -927,7 +1002,7 @@ impl MockProver { .iter() .enumerate() .flat_map(|(lookup_index, lookup)| { - let load = |expression: &Expression, row| { + let load = |expression: &Expression, row: i32| { expression.evaluate_lazy( &|scalar| Value::Real(scalar), &|_| panic!("virtual selectors are removed during optimization"), @@ -935,8 +1010,7 @@ impl MockProver { let query = self.cs.fixed_queries[index]; let column_index = query.0.index(); let rotation = query.1 .0; - self.fixed[column_index] - [(row as i32 + n + rotation) as usize % n as usize] + self.fixed[column_index][(row + n + rotation) as usize % n as usize] .into() }, &|index, _, _| { @@ -944,7 +1018,7 @@ impl MockProver { let column_index = query.0.index(); let rotation = query.1 .0; self.advice[column_index] - [(row as i32 + n + rotation) as usize % n as usize] + [(row + n + rotation) as usize % n as usize] .into() }, &|index, _, _| { @@ -953,7 +1027,7 @@ impl MockProver { let rotation = query.1 .0; Value::Real( self.instance[column_index] - [(row as i32 + n + rotation) as usize % n as usize], + [(row + n + rotation) as usize % n as usize], ) }, &|a| -a, @@ -966,42 +1040,40 @@ impl MockProver { // In the real prover, the lookup expressions are never enforced on // unusable rows, due to the (1 - (l_last(X) + l_blind(X))) term. - let table: std::collections::BTreeSet> = self - .usable_rows - .clone() - .map(|table_row| { + let usable_row_vec: Vec<_> = self.usable_rows.clone().into_iter().collect(); + let table = + parallelizable_iter_map::<_, _, ISPAR>(&usable_row_vec, |table_row| { lookup .table_expressions .iter() - .map(move |c| load(c, table_row)) + .map(move |c| load(c, table_row as i32)) .collect::>() - }) - .collect(); - lookup_input_row_ids - .clone() - .into_iter() - .filter_map(move |input_row| { - let inputs: Vec<_> = lookup - .input_expressions - .iter() - .map(|c| load(c, input_row)) - .collect(); - let lookup_passes = table.contains(&inputs); - if lookup_passes { - None - } else { - Some(VerifyFailure::Lookup { - name: lookup.name, - lookup_index, - location: FailureLocation::find_expressions( - &self.cs, - &self.regions, - input_row, - lookup.input_expressions.iter(), - ), - }) - } - }) + }); + let lookup_input_row_id_vec: Vec<_> = lookup_input_row_ids.clone().collect(); + parallelizable_iter_map::<_, _, ISPAR>(&lookup_input_row_id_vec, |input_row| { + let inputs: Vec<_> = lookup + .input_expressions + .iter() + .map(|c| load(c, input_row as i32)) + .collect(); + let lookup_passes = table.contains(&inputs); + if lookup_passes { + None + } else { + Some(VerifyFailure::Lookup { + name: lookup.name, + lookup_index, + location: FailureLocation::find_expressions( + &self.cs, + &self.regions, + input_row as usize, + lookup.input_expressions.iter(), + ), + }) + } + }) + .into_iter() + .flatten() }); // Check that permutations preserve the original values of the cells. @@ -1028,19 +1100,28 @@ impl MockProver { .flat_map(move |(column, values)| { // Iterate over each row of the column to check that the cell's // value is preserved by the mapping. - values.iter().enumerate().filter_map(move |(row, cell)| { - let original_cell = original(column, row); - let permuted_cell = original(cell.0, cell.1); - if original_cell == permuted_cell { - None - } else { - Some(VerifyFailure::Permutation { - column: (*self.cs.permutation.get_columns().get(column).unwrap()) + let indexes: Vec = (0..values.len()).into_iter().collect(); + let ret: Vec = + parallelizable_iter_filter_map::<_, _, ISPAR>(&indexes, move |row| { + let cell = values[row]; + let original_cell = original(column, row); + let permuted_cell = original(cell.0, cell.1); + if original_cell == permuted_cell { + None + } else { + Some(VerifyFailure::Permutation { + column: (*self + .cs + .permutation + .get_columns() + .get(column) + .unwrap()) .into(), - row, - }) - } - }) + row, + }) + } + }); + ret }) };