From d6bb1af4de1a4e52454fa722c06dc897cb89d2ed Mon Sep 17 00:00:00 2001 From: Bruno Dutra Date: Sat, 23 Nov 2024 19:31:27 +0100 Subject: [PATCH] record the full PV line when searching by fixed depth, nodes, or time --- benches/search.rs | 4 +- lib/nnue/hidden.rs | 4 +- lib/search/driver.rs | 133 +++++++++--------- lib/search/engine.rs | 269 ++++++++++++++++++++---------------- lib/search/pv.rs | 107 ++++++++++---- lib/search/transposition.rs | 4 +- lib/uci.rs | 41 +++--- 7 files changed, 321 insertions(+), 241 deletions(-) diff --git a/benches/search.rs b/benches/search.rs index ab442141..d27dae6c 100644 --- a/benches/search.rs +++ b/benches/search.rs @@ -14,10 +14,10 @@ fn bench(reps: u64, options: &Options, limits: &Limits) -> Duration { for _ in 0..reps { let mut e = Engine::with_options(options); - let interrupter = Trigger::armed(); + let stopper = Trigger::armed(); let pos = Evaluator::default(); let timer = Instant::now(); - e.search(&pos, limits, &interrupter); + e.search::<1>(&pos, limits, &stopper); time += timer.elapsed(); } diff --git a/lib/nnue/hidden.rs b/lib/nnue/hidden.rs index 77ca2a63..0384dd84 100644 --- a/lib/nnue/hidden.rs +++ b/lib/nnue/hidden.rs @@ -15,7 +15,7 @@ impl Hidden { #[inline(always)] #[cfg(target_feature = "avx2")] pub unsafe fn avx2(&self, us: &[i16; N], them: &[i16; N]) -> i32 { - const { assert!(N % 128 == 0) }; + const { assert!(N % 128 == 0) } use std::{arch::x86_64::*, mem::transmute}; @@ -73,7 +73,7 @@ impl Hidden { #[inline(always)] #[cfg(target_feature = "ssse3")] pub unsafe fn sse(&self, us: &[i16; N], them: &[i16; N]) -> i32 { - const { assert!(N % 64 == 0) }; + const { assert!(N % 64 == 0) } use std::{arch::x86_64::*, mem::transmute}; diff --git a/lib/search/driver.rs b/lib/search/driver.rs index 05193576..198669df 100644 --- a/lib/search/driver.rs +++ b/lib/search/driver.rs @@ -1,13 +1,16 @@ -use crate::search::{Interrupted, Pv, ThreadCount}; -use crate::util::{Binary, Bits, Integer}; -use derive_more::{Deref, From}; +use crate::search::{Interrupted, Pv, Score, ThreadCount}; +use crate::util::{Assume, Integer}; +use crate::{chess::Move, nnue::Value}; +use derive_more::From; use rayon::{prelude::*, ThreadPool, ThreadPoolBuilder}; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::cmp::max_by_key; +use std::sync::atomic::{AtomicI16, Ordering}; /// Whether the search should be [`Interrupted`] or exited early. #[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, From)] pub enum ControlFlow { Interrupt(Interrupted), + Continue, Break, } @@ -32,99 +35,89 @@ impl Driver { /// /// The order in which elements are processed and on which thread is unspecified. #[inline(always)] - pub fn drive(&self, mut pv: Pv, moves: &[M], f: F) -> Result + pub fn drive( + &self, + mut head: Move, + mut tail: Pv, + moves: &[(Move, Value)], + f: F, + ) -> Result<(Move, Pv), Interrupted> where - M: Sync, - F: Fn(&Pv, &M) -> Result + Sync, + F: Fn(Score, Move, Value) -> Result, ControlFlow> + Sync, { match self { Self::Sequential => { - for m in moves.iter().rev() { - pv = match f(&pv, m) { - Ok(partial) => partial.max(pv), + for &(m, gain) in moves.iter().rev() { + match f(tail.score(), m, gain) { Err(ControlFlow::Break) => break, + Err(ControlFlow::Continue) => continue, Err(ControlFlow::Interrupt(e)) => return Err(e), + Ok(partial) => { + if partial > tail { + (head, tail) = (m, partial) + } + } }; } - Ok(pv) + Ok((head, tail)) } Self::Parallel(e) => e.install(|| { use Ordering::Relaxed; - let pv = AtomicU64::new(IndexedPv(pv, u32::MAX).encode().get()); - let result = moves.par_iter().enumerate().rev().try_for_each(|(idx, m)| { - let partial = f(&IndexedPv::decode(Bits::new(pv.load(Relaxed))), m)?; - pv.fetch_max(IndexedPv(partial, idx.saturate()).encode().get(), Relaxed); - Ok(()) - }); - - if matches!(result, Ok(()) | Err(ControlFlow::Break)) { - Ok(*IndexedPv::decode(Bits::new(pv.into_inner()))) - } else { - Err(Interrupted) - } + let score = AtomicI16::new(tail.score().get()); + let (head, tail, _) = moves + .par_iter() + .enumerate() + .rev() + .map( + |(idx, &(m, gain))| match f(Score::new(score.load(Relaxed)), m, gain) { + Err(ControlFlow::Break) => None, + Err(ControlFlow::Continue) => Some(Ok(None)), + Err(ControlFlow::Interrupt(e)) => Some(Err(e)), + Ok(partial) => { + score.fetch_max(partial.score().get(), Relaxed); + Some(Ok(Some((m, partial, idx)))) + } + }, + ) + .while_some() + .chain([Ok(Some((head, tail, usize::MAX)))]) + .try_reduce( + || None, + |a, b| { + Ok(max_by_key(a, b, |x| { + x.as_ref().map(|(_, t, i)| (t.score(), *i)) + })) + }, + )? + .assume(); + + Ok((head, tail)) }), } } } -#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Deref)] -#[cfg_attr(test, derive(test_strategy::Arbitrary))] -struct IndexedPv(#[deref] Pv, u32); - -impl Binary for IndexedPv { - type Bits = Bits; - - #[inline(always)] - fn encode(&self) -> Self::Bits { - let mut bits = Bits::default(); - bits.push(self.score().encode()); - bits.push(Bits::::new(self.1)); - bits.push(self.deref().encode()); - bits - } - - #[inline(always)] - fn decode(mut bits: Self::Bits) -> Self { - let best = Binary::decode(bits.pop()); - let idx = bits.pop::().get(); - let score = Binary::decode(bits.pop()); - Self(Pv::new(score, best), idx) - } -} - #[cfg(test)] mod tests { use super::*; use crate::{chess::Move, nnue::Value}; - use std::cmp::max; use test_strategy::proptest; #[proptest] - fn decoding_encoded_indexed_pv_is_an_identity(pv: IndexedPv) { - assert_eq!(IndexedPv::decode(pv.encode()), pv); - } - - #[proptest] - fn indexed_pv_with_higher_score_is_larger(a: Pv, b: Pv, i: u32) { - assert_eq!(a < b, IndexedPv(a, i) < IndexedPv(b, i)); - } + fn driver_finds_pv(c: ThreadCount, h: Move, t: Pv<3>, ms: Vec<(Move, Value)>) { + let (head, tail, _) = ms + .iter() + .enumerate() + .map(|(i, &(m, v))| (m, Pv::new(v.saturate(), []), i)) + .fold((h, t.clone(), usize::MAX), |a, b| { + max_by_key(a, b, |(_, t, i)| (t.score(), *i)) + }); - #[proptest] - fn indexed_pv_with_same_score_but_higher_index_is_larger(pv: Pv, a: u32, b: u32) { - assert_eq!(a < b, IndexedPv(pv, a) < IndexedPv(pv, b)); - } - - #[proptest] - fn driver_finds_max_indexed_pv(c: ThreadCount, pv: Pv, ms: Vec<(Move, Value)>) { assert_eq!( - Driver::new(c).drive(pv, &ms, |_, &(m, v)| Ok(Pv::new(v.saturate(), Some(m)))), - Ok(*ms - .into_iter() - .enumerate() - .map(|(i, (m, v))| IndexedPv(Pv::new(v.saturate(), Some(m)), i as _)) - .fold(IndexedPv(pv, u32::MAX), max)) + Driver::new(c).drive(h, t, &ms, |_, _, v| Ok(Pv::new(v.saturate(), []))), + Ok((head, tail)) ) } } diff --git a/lib/search/engine.rs b/lib/search/engine.rs index 6190c028..b97e87c4 100644 --- a/lib/search/engine.rs +++ b/lib/search/engine.rs @@ -1,6 +1,7 @@ +use crate::chess::{Move, Outcome}; use crate::nnue::{Evaluator, Value}; -use crate::util::{Assume, Counter, Integer, Timer, Trigger}; -use crate::{chess::Outcome, search::*}; +use crate::search::*; +use crate::util::{Counter, Integer, Timer, Trigger}; use arrayvec::ArrayVec; use std::{cell::RefCell, ops::Range, time::Duration}; @@ -44,24 +45,29 @@ impl Engine { } /// Records a `[Transposition`]. - fn record(&self, pos: &Evaluator, bounds: Range, depth: Depth, ply: Ply, pv: Pv) -> Pv { - let m = pv.assume(); - if pv >= bounds.end && m.is_quiet() { - Self::KILLERS.with_borrow_mut(|ks| ks.insert(ply, pos.turn(), m)); + fn record( + &self, + pos: &Evaluator, + bounds: Range, + depth: Depth, + ply: Ply, + best: Move, + score: Score, + ) { + if score >= bounds.end && best.is_quiet() { + Self::KILLERS.with_borrow_mut(|ks| ks.insert(ply, pos.turn(), best)); } self.tt.set( pos.zobrist(), - if pv.score() >= bounds.end { - Transposition::lower(depth - ply, pv.score().normalize(-ply), m) - } else if pv.score() <= bounds.start { - Transposition::upper(depth - ply, pv.score().normalize(-ply), m) + if score >= bounds.end { + Transposition::lower(depth - ply, score.normalize(-ply), best) + } else if score <= bounds.start { + Transposition::upper(depth - ply, score.normalize(-ply), best) } else { - Transposition::exact(depth - ply, pv.score().normalize(-ply), m) + Transposition::exact(depth - ply, score.normalize(-ply), best) }, ); - - pv } /// An implementation of [mate distance pruning]. @@ -108,32 +114,49 @@ impl Engine { Some(depth - r - (depth - ply) / 4) } - /// A [zero-window] alpha-beta search. + /// The [alpha-beta] search. + /// + /// [alpha-beta]: https://www.chessprogramming.org/Alpha-Beta + fn ab( + &self, + pos: &Evaluator, + bounds: Range, + depth: Depth, + ply: Ply, + ctrl: &Control, + ) -> Result, Interrupted> { + if ply < N && depth > ply && bounds.start + 1 < bounds.end { + self.pvs(pos, bounds, depth, ply, ctrl) + } else { + Ok(self.pvs::<0>(pos, bounds, depth, ply, ctrl)?.convert()) + } + } + + /// The [zero-window] alpha-beta search. /// /// [zero-window]: https://www.chessprogramming.org/Null_Window - fn nw( + fn nw( &self, pos: &Evaluator, beta: Score, depth: Depth, ply: Ply, ctrl: &Control, - ) -> Result { - self.pvs(pos, beta - 1..beta, depth, ply, ctrl) + ) -> Result, Interrupted> { + self.ab(pos, beta - 1..beta, depth, ply, ctrl) } - /// An implementation of the [PVS] variation of [alpha-beta pruning] algorithm. + /// An implementation of the [PVS] variation of the alpha-beta search. /// /// [PVS]: https://www.chessprogramming.org/Principal_Variation_Search - /// [alpha-beta pruning]: https://www.chessprogramming.org/Alpha-Beta - fn pvs( + fn pvs( &self, pos: &Evaluator, bounds: Range, depth: Depth, ply: Ply, ctrl: &Control, - ) -> Result { + ) -> Result, Interrupted> { debug_assert!(!bounds.is_empty()); ctrl.interrupted()?; @@ -141,12 +164,12 @@ impl Engine { let (alpha, beta) = match pos.outcome() { None => self.mdp(ply, &bounds), Some(Outcome::DrawByThreefoldRepetition) if is_root => self.mdp(ply, &bounds), - Some(o) if o.is_draw() => return Ok(Pv::new(Score::new(0), None)), - Some(_) => return Ok(Pv::new(Score::lower().normalize(ply), None)), + Some(o) if o.is_draw() => return Ok(Pv::new(Score::new(0), [])), + Some(_) => return Ok(Pv::new(Score::lower().normalize(ply), [])), }; if alpha >= beta { - return Ok(Pv::new(alpha, None)); + return Ok(Pv::new(alpha, [])); } let transposition = self.tt.get(pos.zobrist()); @@ -164,43 +187,43 @@ impl Engine { _ => depth, }; + let transposed = match transposition { + None => Pv::new(pos.evaluate().saturate(), []), + Some(t) => t.transpose(ply), + }; + let is_pv = alpha + 1 < beta; if let Some(t) = transposition { if !is_pv && t.depth() >= depth - ply { let (lower, upper) = t.bounds().into_inner(); if lower >= upper || upper <= alpha || lower >= beta { - return Ok(t.pv(ply)); + return Ok(transposed.convert()); } } } - let pv = match transposition { - None => Pv::new(pos.evaluate().saturate(), None), - Some(t) => t.pv(ply), - }; - - let quiesce = ply >= depth; + let quiesce = depth <= ply; let alpha = match quiesce { #[cfg(not(test))] // The stand pat heuristic is not exact. - true => pv.score().max(alpha), + true => transposed.score().max(alpha), _ => alpha, }; if alpha >= beta || ply >= Ply::MAX { - return Ok(pv); - } else if !is_pv && pv.score() - self.rfp(depth, ply) >= beta { + return Ok(transposed.convert()); + } else if !is_pv && transposed.score() - self.rfp(depth, ply) >= beta { #[cfg(not(test))] // The reverse futility pruning heuristic is not exact. - return Ok(pv); + return Ok(transposed.convert()); } else if !is_pv && !pos.is_check() && pos.pieces(pos.turn()).len() > 1 { - if let Some(d) = self.nmp(pv.score(), beta, depth, ply) { + if let Some(d) = self.nmp(transposed.score(), beta, depth, ply) { let mut next = pos.clone(); next.pass(); - if d <= ply || -self.nw(&next, -beta + 1, d, ply + 1, ctrl)? >= beta { + if d <= ply || -self.nw::<0>(&next, -beta + 1, d, ply + 1, ctrl)? >= beta { #[cfg(not(test))] // The null move pruning heuristic is not exact. - return Ok(pv); + return Ok(transposed.convert()); } } } @@ -210,7 +233,7 @@ impl Engine { .filter(|ms| !quiesce || !ms.is_quiet()) .flatten() .map(|m| { - if Some(m) == *pv { + if Some(m) == transposed.moves().next() { (m, Value::upper()) } else if Self::KILLERS.with_borrow(|ks| ks.contains(ply, pos.turn(), m)) { (m, Value::new(25)) @@ -227,21 +250,22 @@ impl Engine { moves.sort_unstable_by_key(|(_, gain)| *gain); - let pv = match moves.pop() { - None => return Ok(pv), + let (head, tail) = match moves.pop() { + None => return Ok(transposed.convert()), Some((m, _)) => { let mut next = pos.clone(); next.play(m); - m >> -self.pvs(&next, -beta..-alpha, depth, ply + 1, ctrl)? + (m, -self.ab(&next, -beta..-alpha, depth, ply + 1, ctrl)?) } }; - if pv >= beta || moves.is_empty() { - return Ok(self.record(pos, bounds, depth, ply, pv)); + if tail >= beta || moves.is_empty() { + self.record(pos, bounds, depth, ply, head, tail.score()); + return Ok(head >> tail); } - let pv = self.driver.drive(pv, &moves, |&best, &(m, gain)| { - let alpha = match best.score() { + let (head, tail) = self.driver.drive(head, tail, &moves, |score, m, gain| { + let alpha = match score { s if s >= beta => return Err(ControlFlow::Break), s => s.max(alpha), }; @@ -253,39 +277,40 @@ impl Engine { if gain < 0 && !pos.is_check() && !next.is_check() { let guess = -next.evaluate(); if let Some(d) = self.lmp(guess, alpha.saturate(), depth, ply) { - if d <= ply || -self.nw(&next, -alpha, d, ply + 1, ctrl)? <= alpha { + if d <= ply || -self.nw::<0>(&next, -alpha, d, ply + 1, ctrl)? <= alpha { #[cfg(not(test))] // The late move pruning heuristic is not exact. - return Ok(best); + return Err(ControlFlow::Continue); } } } - let pv = match -self.nw(&next, -alpha, depth, ply + 1, ctrl)? { - pv if pv <= alpha || pv >= beta => m >> pv, - _ => m >> -self.pvs(&next, -beta..-alpha, depth, ply + 1, ctrl)?, + let partial = match -self.nw(&next, -alpha, depth, ply + 1, ctrl)? { + partial if partial <= alpha || partial >= beta => partial, + _ => -self.ab(&next, -beta..-alpha, depth, ply + 1, ctrl)?, }; - Ok(pv) + Ok(partial) })?; - Ok(self.record(pos, bounds, depth, ply, pv)) + self.record(pos, bounds, depth, ply, head, tail.score()); + Ok(head >> tail) } /// An implementation of [aspiration windows] with [iterative deepening]. /// /// [aspiration windows]: https://www.chessprogramming.org/Aspiration_Windows /// [iterative deepening]: https://www.chessprogramming.org/Iterative_Deepening - fn aw( + fn aw( &self, pos: &Evaluator, limit: Depth, nodes: u64, time: &Range, - interrupter: &Trigger, - ) -> Pv { - let ctrl = Control::Limited(Counter::new(nodes), Timer::new(time.end), interrupter); - let mut pv = Pv::new(Score::new(0), None); + stopper: &Trigger, + ) -> Pv { + let ctrl = Control::Limited(Counter::new(nodes), Timer::new(time.end), stopper); + let mut pv = Pv::new(Score::lower(), []); 'id: for depth in Depth::iter() { let mut overtime = time.end - time.start; @@ -297,21 +322,22 @@ impl Engine { _ => (pv.score() - delta, pv.score() + delta), }; - let ctrl = if pv.is_none() { + const { assert!(N > 0) } + let ctrl = if pv.moves().next().is_none() { &Control::Unlimited } else if depth < limit { &ctrl } else { - break; + break 'id; }; - pv = 'aw: loop { + 'aw: loop { delta = delta.saturating_mul(2); if ctrl.timer().remaining() < Some(overtime) { break 'id; } - let Ok(partial) = self.pvs(pos, lower..upper, draft, Ply::new(0), ctrl) else { + let Ok(partial) = self.ab(pos, lower..upper, draft, Ply::new(0), ctrl) else { break 'id; }; @@ -330,9 +356,12 @@ impl Engine { pv = partial; } - _ => break 'aw partial, + _ => { + pv = partial; + break 'aw; + } } - }; + } } pv @@ -351,9 +380,14 @@ impl Engine { } /// Searches for the [principal variation][`Pv`]. - pub fn search(&mut self, pos: &Evaluator, limits: &Limits, interrupter: &Trigger) -> Pv { + pub fn search( + &mut self, + pos: &Evaluator, + limits: &Limits, + stopper: &Trigger, + ) -> Pv { let time = self.time_to_search(pos, limits); - self.aw(pos, limits.depth(), limits.nodes(), &time, interrupter) + self.aw(pos, limits.depth(), limits.nodes(), &time, stopper) } } @@ -408,7 +442,7 @@ mod tests { #[proptest] #[should_panic] fn nw_panics_if_beta_is_too_small(e: Engine, pos: Evaluator, d: Depth, p: Ply) { - e.nw(&pos, Score::lower(), d, p, &Control::Unlimited)?; + e.nw::<3>(&pos, Score::lower(), d, p, &Control::Unlimited)?; } #[proptest] @@ -423,10 +457,9 @@ mod tests { #[filter(#s.mate().is_none() && #s >= #b)] s: Score, #[map(|s: Selector| s.select(#pos.moves().flatten()))] m: Move, ) { + use Control::Unlimited; e.tt.set(pos.zobrist(), Transposition::lower(d, s, m)); - - let ctrl = Control::Unlimited; - assert_eq!(e.nw(&pos, b, d, p, &ctrl), Ok(Pv::new(s, Some(m)))); + assert_eq!(e.nw::<3>(&pos, b, d, p, &Unlimited), Ok(Pv::new(s, []))); } #[proptest] @@ -441,10 +474,9 @@ mod tests { #[filter(#s.mate().is_none() && #s < #b)] s: Score, #[map(|s: Selector| s.select(#pos.moves().flatten()))] m: Move, ) { + use Control::Unlimited; e.tt.set(pos.zobrist(), Transposition::upper(d, s, m)); - - let ctrl = Control::Unlimited; - assert_eq!(e.nw(&pos, b, d, p, &ctrl), Ok(Pv::new(s, Some(m)))); + assert_eq!(e.nw::<3>(&pos, b, d, p, &Unlimited), Ok(Pv::new(s, []))); } #[proptest] @@ -459,12 +491,9 @@ mod tests { #[filter(#sc.mate().is_none())] sc: Score, #[map(|s: Selector| s.select(#pos.moves().flatten()))] m: Move, ) { + use Control::Unlimited; e.tt.set(pos.zobrist(), Transposition::exact(d, sc, m)); - - assert_eq!( - e.nw(&pos, b, d, p, &Control::Unlimited), - Ok(Pv::new(sc, Some(m))) - ); + assert_eq!(e.nw::<3>(&pos, b, d, p, &Unlimited), Ok(Pv::new(sc, []))); } #[proptest] @@ -476,82 +505,78 @@ mod tests { #[filter(#p > 0)] p: Ply, ) { assert_eq!( - e.nw(&pos, b, d, p, &Control::Unlimited)? < b, + e.nw::<3>(&pos, b, d, p, &Control::Unlimited)? < b, alphabeta(&pos, b - 1..b, d, p) < b ); } #[proptest] #[should_panic] - fn pvs_panics_if_alpha_is_not_greater_than_beta( + fn ab_panics_if_alpha_is_not_greater_than_beta( e: Engine, pos: Evaluator, b: Range, d: Depth, p: Ply, ) { - e.pvs(&pos, b.end..b.start, d, p, &Control::Unlimited)?; + e.ab::<3>(&pos, b.end..b.start, d, p, &Control::Unlimited)?; } #[proptest] - fn pvs_aborts_if_maximum_number_of_nodes_visited( + fn ab_aborts_if_maximum_number_of_nodes_visited( e: Engine, pos: Evaluator, #[filter(!#b.is_empty())] b: Range, d: Depth, p: Ply, ) { - let interrupter = Trigger::armed(); - let ctrl = Control::Limited(Counter::new(0), Timer::infinite(), &interrupter); - assert_eq!(e.pvs(&pos, b, d, p, &ctrl), Err(Interrupted)); + let trigger = Trigger::armed(); + let ctrl = Control::Limited(Counter::new(0), Timer::infinite(), &trigger); + assert_eq!(e.ab::<3>(&pos, b, d, p, &ctrl), Err(Interrupted)); } #[proptest] - fn pvs_aborts_if_time_is_up( + fn ab_aborts_if_time_is_up( e: Engine, pos: Evaluator, #[filter(!#b.is_empty())] b: Range, d: Depth, p: Ply, ) { - let interrupter = Trigger::armed(); - let ctrl = Control::Limited( - Counter::new(u64::MAX), - Timer::new(Duration::ZERO), - &interrupter, - ); + let trigger = Trigger::armed(); + let ctrl = Control::Limited(Counter::new(u64::MAX), Timer::new(Duration::ZERO), &trigger); std::thread::sleep(Duration::from_millis(1)); - assert_eq!(e.pvs(&pos, b, d, p, &ctrl), Err(Interrupted)); + assert_eq!(e.ab::<3>(&pos, b, d, p, &ctrl), Err(Interrupted)); } #[proptest] - fn pvs_aborts_if_interrupter_is_disarmed( + fn ab_aborts_if_stopper_is_disarmed( e: Engine, pos: Evaluator, #[filter(!#b.is_empty())] b: Range, d: Depth, p: Ply, ) { - let interrupter = Trigger::disarmed(); - let ctrl = Control::Limited(Counter::new(u64::MAX), Timer::infinite(), &interrupter); - assert_eq!(e.pvs(&pos, b, d, p, &ctrl), Err(Interrupted)); + let trigger = Trigger::disarmed(); + let ctrl = Control::Limited(Counter::new(u64::MAX), Timer::infinite(), &trigger); + assert_eq!(e.ab::<3>(&pos, b, d, p, &ctrl), Err(Interrupted)); } #[proptest] - fn pvs_returns_static_evaluation_if_max_ply( + fn ab_returns_static_evaluation_if_max_ply( e: Engine, #[filter(#pos.outcome().is_none())] pos: Evaluator, #[filter(!#b.is_empty())] b: Range, d: Depth, ) { assert_eq!( - e.pvs(&pos, b, d, Ply::upper(), &Control::Unlimited), - Ok(Pv::new(pos.evaluate().saturate(), None)) + e.ab::<3>(&pos, b, d, Ply::upper(), &Control::Unlimited), + Ok(Pv::new(pos.evaluate().saturate(), [])) ); } #[proptest] - fn pvs_returns_drawn_score_if_game_ends_in_a_draw( + fn ab_returns_drawn_score_if_game_ends_in_a_draw( #[by_ref] e: Engine, #[filter(#pos.outcome().is_some_and(|o| o.is_draw()))] pos: Evaluator, #[filter(!#b.is_empty())] b: Range, @@ -559,13 +584,13 @@ mod tests { #[filter(#p > 0 || #pos.outcome() != Some(Outcome::DrawByThreefoldRepetition))] p: Ply, ) { assert_eq!( - e.pvs(&pos, b, d, p, &Control::Unlimited), - Ok(Pv::new(Score::new(0), None)) + e.ab::<3>(&pos, b, d, p, &Control::Unlimited), + Ok(Pv::new(Score::new(0), [])) ); } #[proptest] - fn pvs_returns_lost_score_if_game_ends_in_checkmate( + fn ab_returns_lost_score_if_game_ends_in_checkmate( e: Engine, #[filter(#pos.outcome().is_some_and(|o| o.is_decisive()))] pos: Evaluator, #[filter(!#b.is_empty())] b: Range, @@ -573,31 +598,30 @@ mod tests { p: Ply, ) { assert_eq!( - e.pvs(&pos, b, d, p, &Control::Unlimited), - Ok(Pv::new(Score::lower().normalize(p), None)) + e.ab::<3>(&pos, b, d, p, &Control::Unlimited), + Ok(Pv::new(Score::lower().normalize(p), [])) ); } #[proptest] - fn search_finds_the_principal_variation( - mut e: Engine, - pos: Evaluator, - #[filter(#d > 1)] d: Depth, - ) { - let interrupter = Trigger::armed(); + fn search_finds_the_minimax_score(mut e: Engine, pos: Evaluator, #[filter(#d > 1)] d: Depth) { + let trigger = Trigger::armed(); let time = Duration::MAX..Duration::MAX; assert_eq!( - e.search(&pos, &Limits::Depth(d), &interrupter).score(), - e.aw(&pos, d, u64::MAX, &time, &interrupter).score() + e.search::<1>(&pos, &Limits::Depth(d), &trigger).score(), + e.aw::<1>(&pos, d, u64::MAX, &time, &trigger).score() ); } #[proptest] fn search_is_stable(mut e: Engine, pos: Evaluator, d: Depth) { + let limits = Limits::Depth(d); + let trigger = Trigger::armed(); + assert_eq!( - e.search(&pos, &Limits::Depth(d), &Trigger::armed()).score(), - e.search(&pos, &Limits::Depth(d), &Trigger::armed()).score() + e.search::<1>(&pos, &limits, &trigger).score(), + e.search::<1>(&pos, &limits, &trigger).score() ); } @@ -610,7 +634,7 @@ mod tests { let timer = Instant::now(); let trigger = Trigger::armed(); let limits = Limits::Time(Duration::from_millis(ms.into())); - e.search(&pos, &limits, &trigger); + e.search::<3>(&pos, &limits, &trigger); assert!(timer.elapsed() < Duration::from_secs(1)); } @@ -620,7 +644,8 @@ mod tests { #[filter(#pos.outcome().is_none())] pos: Evaluator, ) { let limits = Duration::ZERO.into(); - assert_ne!(*e.search(&pos, &limits, &Trigger::armed()), None); + let trigger = Trigger::armed(); + assert_ne!(e.search::<3>(&pos, &limits, &trigger).moves().next(), None); } #[proptest] @@ -629,15 +654,17 @@ mod tests { #[filter(#pos.outcome().is_none())] pos: Evaluator, ) { let limits = Depth::lower().into(); - assert_ne!(*e.search(&pos, &limits, &Trigger::armed()), None); + let trigger = Trigger::armed(); + assert_ne!(e.search::<3>(&pos, &limits, &trigger).moves().next(), None); } #[proptest] - fn search_ignores_interrupter_to_find_some_pv( + fn search_ignores_stopper_to_find_some_pv( mut e: Engine, #[filter(#pos.outcome().is_none())] pos: Evaluator, ) { let limits = Limits::None; - assert_ne!(*e.search(&pos, &limits, &Trigger::disarmed()), None); + let trigger = Trigger::armed(); + assert_ne!(e.search::<3>(&pos, &limits, &trigger).moves().next(), None); } } diff --git a/lib/search/pv.rs b/lib/search/pv.rs index 1d64f050..543fbc7d 100644 --- a/lib/search/pv.rs +++ b/lib/search/pv.rs @@ -1,42 +1,74 @@ use crate::{chess::Move, search::Score}; -use derive_more::{Constructor, Deref}; use std::cmp::Ordering; +use std::fmt::{self, Display, Formatter, Write}; use std::ops::{Neg, Shr}; +#[cfg(test)] +use proptest::{collection::vec, prelude::*}; + /// The [principal variation]. /// /// [principal variation]: https://www.chessprogramming.org/Principal_Variation -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Constructor, Deref)] +#[derive(Debug, Clone, Eq, PartialEq, Hash)] #[cfg_attr(test, derive(test_strategy::Arbitrary))] -pub struct Pv { +pub struct Pv { score: Score, - #[deref] - r#move: Option, + #[cfg_attr(test, strategy(vec(any::(), ..=N).prop_map(|ms| { + let mut moves = [None; N]; + for (m, n) in moves.iter_mut().zip(ms) { + *m = Some(n); + } + moves + })))] + moves: [Option; N], } -impl Pv { +impl Pv { + /// Constructs a [`Pv`]. + #[inline(always)] + pub fn new>(score: Score, ms: I) -> Self { + let mut moves = [None; N]; + for (m, n) in moves.iter_mut().zip(ms) { + *m = Some(n); + } + + Pv { score, moves } + } + /// The score from the point of view of the side to move. #[inline(always)] pub fn score(&self) -> Score { self.score } + + /// The sequence of [`Move`]s in this principal variation. + #[inline(always)] + pub fn moves(&self) -> impl Iterator + '_ { + self.moves.iter().map_while(|m| *m) + } + + /// Converts to a principal variation of a different length. + #[inline(always)] + pub fn convert(self) -> Pv { + Pv::new(self.score(), self.moves()) + } } -impl Ord for Pv { +impl Ord for Pv { #[inline(always)] fn cmp(&self, other: &Self) -> Ordering { self.score.cmp(&other.score) } } -impl PartialOrd for Pv { +impl PartialOrd for Pv { #[inline(always)] fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl PartialEq for Pv +impl PartialEq for Pv where Score: PartialEq, { @@ -46,7 +78,7 @@ where } } -impl PartialOrd for Pv +impl PartialOrd for Pv where Score: PartialOrd, { @@ -56,7 +88,7 @@ where } } -impl Neg for Pv { +impl Neg for Pv { type Output = Self; #[inline(always)] @@ -66,48 +98,73 @@ impl Neg for Pv { } } -impl Shr for Move { - type Output = Pv; +impl Shr> for Move { + type Output = Pv; #[inline(always)] - fn shr(self, mut pv: Pv) -> Self::Output { - pv.r#move = Some(self); + fn shr(self, mut pv: Pv) -> Self::Output { + if N > 0 { + pv.moves.copy_within(..N - 1, 1); + pv.moves[0] = Some(self); + } + pv } } +impl Display for Pv { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut moves = self.moves(); + let Some(head) = moves.next() else { + return Ok(()); + }; + + Display::fmt(&head, f)?; + + for m in moves { + f.write_char(' ')?; + Display::fmt(&m, f)?; + } + + Ok(()) + } +} + #[cfg(test)] mod tests { use super::*; use test_strategy::proptest; #[proptest] - fn score_returns_score(pv: Pv) { + fn score_returns_score(pv: Pv<3>) { assert_eq!(pv.score(), pv.score); } #[proptest] - fn negation_changes_score(pv: Pv) { - assert_eq!(pv.neg().score(), -pv.score()); + fn negation_changes_score(pv: Pv<3>) { + assert_eq!(pv.clone().neg().score(), -pv.score()); } #[proptest] - fn negation_preserves_best(pv: Pv) { - assert_eq!(*pv.neg(), *pv); + fn negation_preserves_moves(pv: Pv<3>) { + assert_eq!( + pv.moves().collect::>(), + pv.neg().moves().collect::>() + ); } #[proptest] - fn shift_changes_best(pv: Pv, m: Move) { - assert_eq!(*m.shr(pv), Some(m)); + fn shift_changes_moves(pv: Pv<3>, m: Move) { + assert_eq!(m.shr(pv).moves().next(), Some(m)); } #[proptest] - fn shift_preserves_score(pv: Pv, m: Move) { - assert_eq!(m.shr(pv).score(), pv.score()); + fn shift_preserves_score(pv: Pv<3>, m: Move) { + assert_eq!(m.shr(pv.clone()).score(), pv.score()); } #[proptest] - fn pv_with_larger_score_is_larger(p: Pv, #[filter(#p.score() != #q.score())] q: Pv) { + fn pv_with_larger_score_is_larger(p: Pv<3>, #[filter(#p.score() != #q.score())] q: Pv<3>) { assert_eq!(p < q, p.score() < q.score()); } } diff --git a/lib/search/transposition.rs b/lib/search/transposition.rs index 215c4bdb..f468b304 100644 --- a/lib/search/transposition.rs +++ b/lib/search/transposition.rs @@ -104,8 +104,8 @@ impl Transposition { /// Principal variation normalized to [`Ply`]. #[inline(always)] - pub fn pv(&self, ply: Ply) -> Pv { - Pv::new(self.score().normalize(ply), Some(self.best)) + pub fn transpose(&self, ply: Ply) -> Pv<1> { + Pv::new(self.score().normalize(ply), [self.best]) } } diff --git a/lib/uci.rs b/lib/uci.rs index b01a7abc..9e8989cd 100644 --- a/lib/uci.rs +++ b/lib/uci.rs @@ -1,6 +1,6 @@ use crate::chess::{Color, Move, Perspective}; use crate::nnue::Evaluator; -use crate::search::{Engine, HashSize, Limits, Options, ThreadCount}; +use crate::search::{Depth, Engine, HashSize, Limits, Options, ThreadCount}; use crate::util::{Assume, Integer, Trigger}; use futures::channel::oneshot::channel as oneshot; use futures::{future::FusedFuture, prelude::*, select_biased as select, stream::FusedStream}; @@ -71,11 +71,11 @@ impl Uci { } impl + Unpin, O: Sink + Unpin> Uci { - async fn go(&mut self, limits: &Limits) -> Result<(), O::Error> { - let interrupter = Trigger::armed(); + async fn go(&mut self, limits: &Limits) -> Result<(), O::Error> { + let stopper = Trigger::armed(); let mut search = - unsafe { unblock(|| self.engine.search(&self.position, limits, &interrupter)) }; + unsafe { unblock(|| self.engine.search::(&self.position, limits, &stopper)) }; let pv = loop { select! { @@ -83,7 +83,7 @@ impl + Unpin, O: Sink + Unpin> Uci { line = self.input.next() => { match line.as_deref().map(str::trim) { None => break search.await, - Some("stop") => { interrupter.disarm(); }, + Some("stop") => { stopper.disarm(); }, Some(cmd) => eprintln!("ignored unsupported command `{cmd}` during search"), } } @@ -91,14 +91,15 @@ impl + Unpin, O: Sink + Unpin> Uci { }; let info = match pv.score().mate() { - Some(p) if p > 0 => format!("info score mate {}", (p + 1) / 2), - Some(p) => format!("info score mate {}", (p - 1) / 2), - None => format!("info score cp {:+}", pv.score()), + Some(p) if p > 0 => format!("info score mate {} pv {pv}", (p + 1) / 2), + Some(p) => format!("info score mate {} pv {pv}", (p - 1) / 2), + None => format!("info score cp {:+} pv {pv}", pv.score()), }; self.output.send(info).await?; - if let Some(m) = *pv { + const { assert!(N > 0) } + if let Some(m) = pv.moves().next() { self.output.send(format!("bestmove {m}")).await?; } @@ -106,9 +107,9 @@ impl + Unpin, O: Sink + Unpin> Uci { } async fn bench(&mut self, limits: &Limits) -> Result<(), O::Error> { - let interrupter = Trigger::armed(); + let stopper = Trigger::armed(); let timer = Instant::now(); - self.engine.search(&self.position, limits, &interrupter); + self.engine.search::<1>(&self.position, limits, &stopper); let millis = timer.elapsed().as_millis(); let info = match limits { @@ -142,27 +143,29 @@ impl + Unpin, O: Sink + Unpin> Uci { (Ok(t), Ok(i)) => { let t = Duration::from_millis(t); let i = Duration::from_millis(i); - self.go(&Limits::Clock(t, i)).await?; + self.go::<1>(&Limits::Clock(t, i)).await?; } } } - ["go", "depth", depth] => match depth.parse() { - Ok(d) => self.go(&Limits::Depth(d)).await?, + ["go", "movetime", time] => match time.parse() { + Ok(ms) => self.go::<{ Depth::MAX as _ }>(&Duration::from_millis(ms).into()).await?, Err(e) => eprintln!("{e}"), }, - ["go", "nodes", nodes] => match nodes.parse() { - Ok(n) => self.go(&Limits::Nodes(n)).await?, + ["go", "depth", depth] => match depth.parse() { + Ok(d) => self.go::<{ Depth::MAX as _ }>(&Limits::Depth(d)).await?, Err(e) => eprintln!("{e}"), }, - ["go", "movetime", time] => match time.parse() { - Ok(ms) => self.go(&Duration::from_millis(ms).into()).await?, + ["go", "nodes", nodes] => match nodes.parse() { + Ok(n) => self.go::<{ Depth::MAX as _ }>(&Limits::Nodes(n)).await?, Err(e) => eprintln!("{e}"), }, - ["go"] | ["go", "infinite"] => self.go(&Limits::None).await?, + ["go"] | ["go", "infinite"] => { + self.go::<{ Depth::MAX as _ }>(&Limits::None).await? + } ["bench", "depth", depth] => match depth.parse() { Ok(d) => self.bench(&Limits::Depth(d)).await?,