Skip to content

Commit

Permalink
Fix parallel function signature.
Browse files Browse the repository at this point in the history
  • Loading branch information
spherel committed Aug 7, 2022
1 parent 0b83e58 commit 55bae7c
Showing 1 changed file with 52 additions and 39 deletions.
91 changes: 52 additions & 39 deletions halo2_proofs/src/dev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -672,34 +672,45 @@ impl<F: Field + Group> Assignment<F> for MockProver<F> {
}
}

fn home_made_par_iter_flat_map<U: Send, Fun: Send + Sync, const ISPAR: bool>(
fn parallelizable_iter_flat_map<U: Send, Fun: Send + Sync, const ISPAR: bool>(
row_indexes: &[usize],
f: Fun,
) -> Vec<U>
where
Fun: Fn(i32) -> Vec<U>,
Fun: Fn(usize) -> Vec<U>,
{
if ISPAR {
row_indexes
.par_iter()
.flat_map(|row| f(*row as i32))
.collect()
row_indexes.par_iter().flat_map(|row| f(*row)).collect()
} else {
row_indexes.iter().flat_map(|row| f(*row as i32)).collect()
row_indexes.iter().flat_map(|row| f(*row)).collect()
}
}

fn home_made_par_iter_map<U: Send + Default, Fun: Send + Sync, const ISPAR: bool>(
fn parallelizable_iter_map<U: Send + Default, Fun: Send + Sync, const ISPAR: bool>(
row_indexes: &[usize],
f: Fun,
) -> Vec<U>
where
Fun: Fn(i32) -> U,
Fun: Fn(usize) -> U,
{
if ISPAR {
row_indexes.par_iter().map(|row| f(*row as i32)).collect()
row_indexes.par_iter().map(|row| f(*row)).collect()
} else {
row_indexes.iter().map(|row| f(*row as i32)).collect()
row_indexes.iter().map(|row| f(*row)).collect()
}
}

fn parallelizable_iter_filter_map<U: Send, Fun: Send + Sync, const ISPAR: bool>(
row_indexes: &[usize],
f: Fun,
) -> Vec<U>
where
Fun: Fn(usize) -> Option<U>,
{
if ISPAR {
row_indexes.par_iter().filter_map(|row| f(*row)).collect()
} else {
row_indexes.iter().filter_map(|row| f(*row)).collect()
}
}

Expand Down Expand Up @@ -860,13 +871,14 @@ impl<F: FieldExt> MockProver<F> {
.enumerate()
.filter(move |(_, g)| g.queried_selectors().contains(selector))
.flat_map(move |(gate_index, gate)| {
home_made_par_iter_flat_map::<_, _, ISPAR>(at, |gate_row| {
parallelizable_iter_flat_map::<_, _, ISPAR>(at, |gate_row| {
let ret: Vec<VerifyFailure> = gate
.queried_cells()
.iter()
.filter_map(move |cell| {
// Determine where this cell should have been assigned.
let cell_row = ((gate_row + n + cell.rotation.0) % n) as usize;
let cell_row =
(((gate_row as i32) + n + cell.rotation.0) % n) as usize;

// Check that it was assigned!
if r.is_assigned(cell.column, cell_row) {
Expand Down Expand Up @@ -922,7 +934,8 @@ impl<F: FieldExt> MockProver<F> {
cells[column.index()][resolved_row as usize].into()
}
}
home_made_par_iter_flat_map::<_, _, ISPAR>(&indexes, |row| {
parallelizable_iter_flat_map::<_, _, ISPAR>(&indexes, |row| {
let row = row as i32;
let ret: Vec<VerifyFailure> = gate
.polynomials()
.iter()
Expand Down Expand Up @@ -1014,7 +1027,7 @@ impl<F: FieldExt> MockProver<F> {
let rotation = query.1 .0;
Value::Real(
self.instance[column_index]
[(row as i32 + n + rotation) as usize % n as usize],
[(row + n + rotation) as usize % n as usize],
)
},
&|a| -a,
Expand All @@ -1029,19 +1042,19 @@ impl<F: FieldExt> MockProver<F> {
// unusable rows, due to the (1 - (l_last(X) + l_blind(X))) term.
let usable_row_vec: Vec<_> = self.usable_rows.clone().into_iter().collect();
let table =
home_made_par_iter_map::<_, _, ISPAR>(&usable_row_vec, |table_row| {
parallelizable_iter_map::<_, _, ISPAR>(&usable_row_vec, |table_row| {
lookup
.table_expressions
.iter()
.map(move |c| load(c, table_row))
.map(move |c| load(c, table_row as i32))
.collect::<Vec<_>>()
});
let lookup_input_row_id_vec: Vec<_> = lookup_input_row_ids.clone().collect();
home_made_par_iter_map::<_, _, ISPAR>(&lookup_input_row_id_vec, |input_row| {
parallelizable_iter_map::<_, _, ISPAR>(&lookup_input_row_id_vec, |input_row| {
let inputs: Vec<_> = lookup
.input_expressions
.iter()
.map(|c| load(c, input_row))
.map(|c| load(c, input_row as i32))
.collect();
let lookup_passes = table.contains(&inputs);
if lookup_passes {
Expand Down Expand Up @@ -1087,27 +1100,27 @@ impl<F: FieldExt> MockProver<F> {
.flat_map(move |(column, values)| {
// Iterate over each row of the column to check that the cell's
// value is preserved by the mapping.
let f = move |row, cell: (usize, usize)| -> Option<VerifyFailure> {
let original_cell = original(column, row);
let permuted_cell = original(cell.0, cell.1);
if original_cell == permuted_cell {
None
} else {
Some(VerifyFailure::Permutation {
column: (*self.cs.permutation.get_columns().get(column).unwrap())
.into(),
row,
})
}
};
let indexes: Vec<usize> = (0..values.len()).into_iter().collect();
let ret: Vec<VerifyFailure> = indexes
.par_iter()
.filter_map(|row| {
let cell = values[*row];
f(*row, cell)
})
.collect();
let ret: Vec<VerifyFailure> =
parallelizable_iter_filter_map::<_, _, ISPAR>(&indexes, move |row| {
let cell = values[row];
let original_cell = original(column, row);
let permuted_cell = original(cell.0, cell.1);
if original_cell == permuted_cell {
None
} else {
Some(VerifyFailure::Permutation {
column: (*self
.cs
.permutation
.get_columns()
.get(column)
.unwrap())
.into(),
row,
})
}
});
ret
})
};
Expand Down

0 comments on commit 55bae7c

Please sign in to comment.