Skip to content

Commit

Permalink
record the full PV line when searching by fixed depth, nodes, or time
Browse files Browse the repository at this point in the history
  • Loading branch information
brunocodutra committed Nov 24, 2024
1 parent 0ba033a commit d6bb1af
Show file tree
Hide file tree
Showing 7 changed files with 321 additions and 241 deletions.
4 changes: 2 additions & 2 deletions benches/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ fn bench(reps: u64, options: &Options, limits: &Limits) -> Duration {

for _ in 0..reps {
let mut e = Engine::with_options(options);
let interrupter = Trigger::armed();
let stopper = Trigger::armed();
let pos = Evaluator::default();
let timer = Instant::now();
e.search(&pos, limits, &interrupter);
e.search::<1>(&pos, limits, &stopper);
time += timer.elapsed();
}

Expand Down
4 changes: 2 additions & 2 deletions lib/nnue/hidden.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ impl<const N: usize> Hidden<N> {
#[inline(always)]
#[cfg(target_feature = "avx2")]
pub unsafe fn avx2(&self, us: &[i16; N], them: &[i16; N]) -> i32 {
const { assert!(N % 128 == 0) };
const { assert!(N % 128 == 0) }

use std::{arch::x86_64::*, mem::transmute};

Expand Down Expand Up @@ -73,7 +73,7 @@ impl<const N: usize> Hidden<N> {
#[inline(always)]
#[cfg(target_feature = "ssse3")]
pub unsafe fn sse(&self, us: &[i16; N], them: &[i16; N]) -> i32 {
const { assert!(N % 64 == 0) };
const { assert!(N % 64 == 0) }

use std::{arch::x86_64::*, mem::transmute};

Expand Down
133 changes: 63 additions & 70 deletions lib/search/driver.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use crate::search::{Interrupted, Pv, ThreadCount};
use crate::util::{Binary, Bits, Integer};
use derive_more::{Deref, From};
use crate::search::{Interrupted, Pv, Score, ThreadCount};
use crate::util::{Assume, Integer};
use crate::{chess::Move, nnue::Value};
use derive_more::From;
use rayon::{prelude::*, ThreadPool, ThreadPoolBuilder};
use std::sync::atomic::{AtomicU64, Ordering};
use std::cmp::max_by_key;
use std::sync::atomic::{AtomicI16, Ordering};

/// Whether the search should be [`Interrupted`] or exited early.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, From)]
pub enum ControlFlow {
Interrupt(Interrupted),
Continue,
Break,
}

Expand All @@ -32,99 +35,89 @@ impl Driver {
///
/// The order in which elements are processed and on which thread is unspecified.
#[inline(always)]
pub fn drive<M, F>(&self, mut pv: Pv, moves: &[M], f: F) -> Result<Pv, Interrupted>
pub fn drive<F, const N: usize>(
&self,
mut head: Move,
mut tail: Pv<N>,
moves: &[(Move, Value)],
f: F,
) -> Result<(Move, Pv<N>), Interrupted>
where
M: Sync,
F: Fn(&Pv, &M) -> Result<Pv, ControlFlow> + Sync,
F: Fn(Score, Move, Value) -> Result<Pv<N>, ControlFlow> + Sync,
{
match self {
Self::Sequential => {
for m in moves.iter().rev() {
pv = match f(&pv, m) {
Ok(partial) => partial.max(pv),
for &(m, gain) in moves.iter().rev() {
match f(tail.score(), m, gain) {
Err(ControlFlow::Break) => break,
Err(ControlFlow::Continue) => continue,
Err(ControlFlow::Interrupt(e)) => return Err(e),
Ok(partial) => {
if partial > tail {
(head, tail) = (m, partial)
}
}
};
}

Ok(pv)
Ok((head, tail))
}

Self::Parallel(e) => e.install(|| {
use Ordering::Relaxed;
let pv = AtomicU64::new(IndexedPv(pv, u32::MAX).encode().get());
let result = moves.par_iter().enumerate().rev().try_for_each(|(idx, m)| {
let partial = f(&IndexedPv::decode(Bits::new(pv.load(Relaxed))), m)?;
pv.fetch_max(IndexedPv(partial, idx.saturate()).encode().get(), Relaxed);
Ok(())
});

if matches!(result, Ok(()) | Err(ControlFlow::Break)) {
Ok(*IndexedPv::decode(Bits::new(pv.into_inner())))
} else {
Err(Interrupted)
}
let score = AtomicI16::new(tail.score().get());
let (head, tail, _) = moves
.par_iter()
.enumerate()
.rev()
.map(
|(idx, &(m, gain))| match f(Score::new(score.load(Relaxed)), m, gain) {
Err(ControlFlow::Break) => None,
Err(ControlFlow::Continue) => Some(Ok(None)),
Err(ControlFlow::Interrupt(e)) => Some(Err(e)),
Ok(partial) => {
score.fetch_max(partial.score().get(), Relaxed);
Some(Ok(Some((m, partial, idx))))
}
},
)
.while_some()
.chain([Ok(Some((head, tail, usize::MAX)))])
.try_reduce(
|| None,
|a, b| {
Ok(max_by_key(a, b, |x| {
x.as_ref().map(|(_, t, i)| (t.score(), *i))
}))
},
)?
.assume();

Ok((head, tail))
}),
}
}
}

#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Deref)]
#[cfg_attr(test, derive(test_strategy::Arbitrary))]
struct IndexedPv(#[deref] Pv, u32);

impl Binary for IndexedPv {
type Bits = Bits<u64, 64>;

#[inline(always)]
fn encode(&self) -> Self::Bits {
let mut bits = Bits::default();
bits.push(self.score().encode());
bits.push(Bits::<u32, 32>::new(self.1));
bits.push(self.deref().encode());
bits
}

#[inline(always)]
fn decode(mut bits: Self::Bits) -> Self {
let best = Binary::decode(bits.pop());
let idx = bits.pop::<u32, 32>().get();
let score = Binary::decode(bits.pop());
Self(Pv::new(score, best), idx)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{chess::Move, nnue::Value};
use std::cmp::max;
use test_strategy::proptest;

#[proptest]
fn decoding_encoded_indexed_pv_is_an_identity(pv: IndexedPv) {
assert_eq!(IndexedPv::decode(pv.encode()), pv);
}

#[proptest]
fn indexed_pv_with_higher_score_is_larger(a: Pv, b: Pv, i: u32) {
assert_eq!(a < b, IndexedPv(a, i) < IndexedPv(b, i));
}
fn driver_finds_pv(c: ThreadCount, h: Move, t: Pv<3>, ms: Vec<(Move, Value)>) {
let (head, tail, _) = ms
.iter()
.enumerate()
.map(|(i, &(m, v))| (m, Pv::new(v.saturate(), []), i))
.fold((h, t.clone(), usize::MAX), |a, b| {
max_by_key(a, b, |(_, t, i)| (t.score(), *i))
});

#[proptest]
fn indexed_pv_with_same_score_but_higher_index_is_larger(pv: Pv, a: u32, b: u32) {
assert_eq!(a < b, IndexedPv(pv, a) < IndexedPv(pv, b));
}

#[proptest]
fn driver_finds_max_indexed_pv(c: ThreadCount, pv: Pv, ms: Vec<(Move, Value)>) {
assert_eq!(
Driver::new(c).drive(pv, &ms, |_, &(m, v)| Ok(Pv::new(v.saturate(), Some(m)))),
Ok(*ms
.into_iter()
.enumerate()
.map(|(i, (m, v))| IndexedPv(Pv::new(v.saturate(), Some(m)), i as _))
.fold(IndexedPv(pv, u32::MAX), max))
Driver::new(c).drive(h, t, &ms, |_, _, v| Ok(Pv::new(v.saturate(), []))),
Ok((head, tail))
)
}
}
Loading

0 comments on commit d6bb1af

Please sign in to comment.