From 143d9770f36f1fd6af40f8f53cdb391f0dc03343 Mon Sep 17 00:00:00 2001 From: Bruno Dutra Date: Sun, 17 Nov 2024 21:20:56 +0100 Subject: [PATCH 1/3] impl Deref> for Pv --- lib/search/driver.rs | 18 +++++++++--------- lib/search/engine.rs | 34 +++++++++++++++++----------------- lib/search/pv.rs | 24 +++++++----------------- lib/search/transposition.rs | 8 ++++---- lib/uci.rs | 2 +- 5 files changed, 38 insertions(+), 48 deletions(-) diff --git a/lib/search/driver.rs b/lib/search/driver.rs index 4754a616..05193576 100644 --- a/lib/search/driver.rs +++ b/lib/search/driver.rs @@ -32,7 +32,7 @@ impl Driver { /// /// The order in which elements are processed and on which thread is unspecified. #[inline(always)] - pub fn drive(&self, mut best: Pv, moves: &[M], f: F) -> Result + pub fn drive(&self, mut pv: Pv, moves: &[M], f: F) -> Result where M: Sync, F: Fn(&Pv, &M) -> Result + Sync, @@ -40,27 +40,27 @@ impl Driver { match self { Self::Sequential => { for m in moves.iter().rev() { - best = match f(&best, m) { - Ok(pv) => pv.max(best), + pv = match f(&pv, m) { + Ok(partial) => partial.max(pv), Err(ControlFlow::Break) => break, Err(ControlFlow::Interrupt(e)) => return Err(e), }; } - Ok(best) + Ok(pv) } Self::Parallel(e) => e.install(|| { use Ordering::Relaxed; - let best = AtomicU64::new(IndexedPv(best, u32::MAX).encode().get()); + let pv = AtomicU64::new(IndexedPv(pv, u32::MAX).encode().get()); let result = moves.par_iter().enumerate().rev().try_for_each(|(idx, m)| { - let pv = f(&IndexedPv::decode(Bits::new(best.load(Relaxed))), m)?; - best.fetch_max(IndexedPv(pv, idx.saturate()).encode().get(), Relaxed); + let partial = f(&IndexedPv::decode(Bits::new(pv.load(Relaxed))), m)?; + pv.fetch_max(IndexedPv(partial, idx.saturate()).encode().get(), Relaxed); Ok(()) }); if matches!(result, Ok(()) | Err(ControlFlow::Break)) { - Ok(*IndexedPv::decode(Bits::new(best.into_inner()))) + Ok(*IndexedPv::decode(Bits::new(pv.into_inner()))) } else { Err(Interrupted) } @@ -81,7 +81,7 @@ impl Binary for IndexedPv { let mut bits = Bits::default(); bits.push(self.score().encode()); bits.push(Bits::::new(self.1)); - bits.push(self.best().encode()); + bits.push(self.deref().encode()); bits } diff --git a/lib/search/engine.rs b/lib/search/engine.rs index 2cb52110..8ea43163 100644 --- a/lib/search/engine.rs +++ b/lib/search/engine.rs @@ -45,7 +45,7 @@ impl Engine { /// Records a `[Transposition`]. fn record(&self, pos: &Evaluator, bounds: Range, depth: Depth, ply: Ply, pv: Pv) -> Pv { - let m = pv.best().assume(); + let m = pv.assume(); if pv >= bounds.end && m.is_quiet() { Self::KILLERS.with_borrow_mut(|ks| ks.insert(ply, pos.turn(), m)); } @@ -177,38 +177,38 @@ impl Engine { if !is_pv && t.depth() >= depth - ply { let (lower, upper) = t.bounds().into_inner(); if lower >= upper || upper <= alpha || lower >= beta { - return Ok(Pv::new(t.score().normalize(ply), Some(t.best()))); + return Ok(t.pv(ply)); } } } - let score = match transposition { - Some(t) => t.score().normalize(ply), - _ => pos.evaluate().saturate(), + let pv = match transposition { + None => Pv::new(pos.evaluate().saturate(), None), + Some(t) => t.pv(ply), }; let quiesce = ply >= depth; let alpha = match quiesce { #[cfg(not(test))] // The stand pat heuristic is not exact. - true => alpha.max(score), + true => pv.score().max(alpha), _ => alpha, }; if alpha >= beta || ply >= Ply::MAX { - return Ok(Pv::new(score, None)); - } else if score - self.rfp(depth, ply) >= beta { + return Ok(pv); + } else if pv.score() - self.rfp(depth, ply) >= beta { #[cfg(not(test))] // The reverse futility pruning heuristic is not exact. - return Ok(Pv::new(score, None)); + return Ok(pv); } else if !is_pv && !pos.is_check() && pos.pieces(pos.turn()).len() > 1 { - if let Some(d) = self.nmp(score, beta, depth, ply) { + if let Some(d) = self.nmp(pv.score(), beta, depth, ply) { let mut next = pos.clone(); next.pass(); if d <= ply || -self.nw(&next, -beta + 1, d, ply + 1, ctrl)? >= beta { #[cfg(not(test))] // The null move pruning heuristic is not exact. - return Ok(Pv::new(score, None)); + return Ok(pv); } } } @@ -218,7 +218,7 @@ impl Engine { .filter(|ms| !quiesce || !ms.is_quiet()) .flatten() .map(|m| { - if Some(m) == transposition.map(|t| t.best()) { + if Some(m) == *pv { (m, Value::upper()) } else if Self::KILLERS.with_borrow(|ks| ks.contains(ply, pos.turn(), m)) { (m, Value::new(25)) @@ -236,7 +236,7 @@ impl Engine { moves.sort_unstable_by_key(|(_, gain)| *gain); let pv = match moves.pop() { - None => return Ok(Pv::new(score, None)), + None => return Ok(pv), Some((m, _)) => { let mut next = pos.clone(); next.play(m); @@ -300,7 +300,7 @@ impl Engine { use Control::*; pv = self.fw(pos, depth, Ply::new(0), &Unlimited).assume(); depth = depth + 1; - if pv.best().is_some() { + if pv.is_some() { break; } } @@ -653,7 +653,7 @@ mod tests { #[filter(#pos.outcome().is_none())] pos: Evaluator, ) { let limits = Duration::ZERO.into(); - assert_ne!(e.search(&pos, &limits, &Trigger::armed()).best(), None); + assert_ne!(*e.search(&pos, &limits, &Trigger::armed()), None); } #[proptest] @@ -662,7 +662,7 @@ mod tests { #[filter(#pos.outcome().is_none())] pos: Evaluator, ) { let limits = Depth::lower().into(); - assert_ne!(e.search(&pos, &limits, &Trigger::armed()).best(), None); + assert_ne!(*e.search(&pos, &limits, &Trigger::armed()), None); } #[proptest] @@ -671,6 +671,6 @@ mod tests { #[filter(#pos.outcome().is_none())] pos: Evaluator, ) { let limits = Limits::None; - assert_ne!(e.search(&pos, &limits, &Trigger::disarmed()).best(), None); + assert_ne!(*e.search(&pos, &limits, &Trigger::disarmed()), None); } } diff --git a/lib/search/pv.rs b/lib/search/pv.rs index 7ff7964a..1d64f050 100644 --- a/lib/search/pv.rs +++ b/lib/search/pv.rs @@ -1,35 +1,25 @@ use crate::{chess::Move, search::Score}; +use derive_more::{Constructor, Deref}; use std::cmp::Ordering; use std::ops::{Neg, Shr}; /// The [principal variation]. /// /// [principal variation]: https://www.chessprogramming.org/Principal_Variation -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Constructor, Deref)] #[cfg_attr(test, derive(test_strategy::Arbitrary))] pub struct Pv { score: Score, - best: Option, + #[deref] + r#move: Option, } impl Pv { - /// Constructs a pv. - #[inline(always)] - pub fn new(score: Score, best: Option) -> Self { - Pv { score, best } - } - /// The score from the point of view of the side to move. #[inline(always)] pub fn score(&self) -> Score { self.score } - - /// An iterator over [`Move`]s in this principal variation. - #[inline(always)] - pub fn best(&self) -> Option { - self.best - } } impl Ord for Pv { @@ -81,7 +71,7 @@ impl Shr for Move { #[inline(always)] fn shr(self, mut pv: Pv) -> Self::Output { - pv.best = Some(self); + pv.r#move = Some(self); pv } } @@ -103,12 +93,12 @@ mod tests { #[proptest] fn negation_preserves_best(pv: Pv) { - assert_eq!(pv.neg().best(), pv.best()); + assert_eq!(*pv.neg(), *pv); } #[proptest] fn shift_changes_best(pv: Pv, m: Move) { - assert_eq!(m.shr(pv).best(), Some(m)); + assert_eq!(*m.shr(pv), Some(m)); } #[proptest] diff --git a/lib/search/transposition.rs b/lib/search/transposition.rs index dfd7e26e..215c4bdb 100644 --- a/lib/search/transposition.rs +++ b/lib/search/transposition.rs @@ -1,5 +1,5 @@ use crate::chess::{Move, Zobrist}; -use crate::search::{Depth, HashSize, Score}; +use crate::search::{Depth, HashSize, Ply, Pv, Score}; use crate::util::{Assume, Binary, Bits, Integer}; use derive_more::Debug; use std::mem::size_of; @@ -102,10 +102,10 @@ impl Transposition { self.score } - /// Best [`Move`] at this depth. + /// Principal variation normalized to [`Ply`]. #[inline(always)] - pub fn best(&self) -> Move { - self.best + pub fn pv(&self, ply: Ply) -> Pv { + Pv::new(self.score().normalize(ply), Some(self.best)) } } diff --git a/lib/uci.rs b/lib/uci.rs index be404194..b01a7abc 100644 --- a/lib/uci.rs +++ b/lib/uci.rs @@ -98,7 +98,7 @@ impl + Unpin, O: Sink + Unpin> Uci { self.output.send(info).await?; - if let Some(m) = pv.best() { + if let Some(m) = *pv { self.output.send(format!("bestmove {m}")).await?; } From 975efe015bc1fce67251cdc0057766015cd9c8e1 Mon Sep 17 00:00:00 2001 From: Bruno Dutra Date: Sun, 17 Nov 2024 18:21:37 +0100 Subject: [PATCH 2/3] do not reduce or extend at the root --- lib/search/engine.rs | 111 +++++++++++++++---------------------------- 1 file changed, 39 insertions(+), 72 deletions(-) diff --git a/lib/search/engine.rs b/lib/search/engine.rs index 8ea43163..e01654b4 100644 --- a/lib/search/engine.rs +++ b/lib/search/engine.rs @@ -1,6 +1,6 @@ use crate::nnue::{Evaluator, Value}; -use crate::search::*; use crate::util::{Assume, Counter, Integer, Timer, Trigger}; +use crate::{chess::Outcome, search::*}; use arrayvec::ArrayVec; use std::{cell::RefCell, ops::Range, time::Duration}; @@ -108,17 +108,6 @@ impl Engine { Some(depth - r - (depth - ply) / 4) } - /// A full alpha-beta search. - fn fw( - &self, - pos: &Evaluator, - depth: Depth, - ply: Ply, - ctrl: &Control, - ) -> Result { - self.pvs(pos, Score::lower()..Score::upper(), depth, ply, ctrl) - } - /// A [zero-window] alpha-beta search. /// /// [zero-window]: https://www.chessprogramming.org/Null_Window @@ -148,19 +137,22 @@ impl Engine { debug_assert!(!bounds.is_empty()); ctrl.interrupted()?; + let is_root = ply == 0; + let (alpha, beta) = match pos.outcome() { + None => self.mdp(ply, &bounds), + Some(Outcome::DrawByThreefoldRepetition) if is_root => self.mdp(ply, &bounds), + Some(o) if o.is_draw() => return Ok(Pv::new(Score::new(0), None)), + Some(_) => return Ok(Pv::new(Score::lower().normalize(ply), None)), + }; - let (alpha, beta) = self.mdp(ply, &bounds); if alpha >= beta { return Ok(Pv::new(alpha, None)); } - let transposition = match pos.outcome() { - Some(o) if o.is_draw() => return Ok(Pv::new(Score::new(0), None)), - Some(_) => return Ok(Pv::new(Score::lower().normalize(ply), None)), - None => self.tt.get(pos.zobrist()), - }; - + let transposition = self.tt.get(pos.zobrist()); let depth = match transposition { + _ if is_root => depth, + #[cfg(not(test))] // Extensions are not exact. Some(_) if pos.is_check() => depth + 1, @@ -294,48 +286,46 @@ impl Engine { ) -> Pv { let ctrl = Control::Limited(Counter::new(nodes), Timer::new(time.end), interrupter); let mut pv = Pv::new(Score::new(0), None); - let mut depth = Depth::new(0); - - while depth < Depth::upper() { - use Control::*; - pv = self.fw(pos, depth, Ply::new(0), &Unlimited).assume(); - depth = depth + 1; - if pv.is_some() { - break; - } - } - 'id: for d in depth.get()..=limit.get() { + 'id: for depth in Depth::iter() { let mut overtime = time.end - time.start; - let mut depth = Depth::new(d); - let mut delta: i16 = 5; + let mut draft = depth; + let mut delta = 5i16; - let (mut lower, mut upper) = match d { + let (mut lower, mut upper) = match depth.get() { ..=4 => (Score::lower(), Score::upper()), _ => (pv.score() - delta, pv.score() + delta), }; + let ctrl = if pv.is_none() { + &Control::Unlimited + } else if depth < limit { + &ctrl + } else { + break; + }; + pv = 'aw: loop { delta = delta.saturating_mul(2); if ctrl.timer().remaining() < Some(overtime) { break 'id; } - let Ok(partial) = self.pvs(pos, lower..upper, depth, Ply::new(0), &ctrl) else { + let Ok(partial) = self.pvs(pos, lower..upper, draft, Ply::new(0), ctrl) else { break 'id; }; match partial.score() { score if (-lower..Score::upper()).contains(&-score) => { overtime /= 2; - depth = Depth::new(d); + draft = depth; upper = lower / 2 + upper / 2; lower = score - delta; } score if (upper..Score::upper()).contains(&score) => { overtime = time.end - time.start; - depth = depth - 1; + draft = draft - 1; upper = score + delta; pv = partial; } @@ -408,10 +398,6 @@ mod tests { alpha } - fn negamax(pos: &Evaluator, depth: Depth, ply: Ply) -> Score { - alphabeta(pos, Score::lower()..Score::upper(), depth, ply) - } - #[proptest] fn hash_is_an_upper_limit_for_table_size(o: Options) { let e = Engine::with_options(&o); @@ -475,17 +461,19 @@ mod tests { ) { e.tt.set(pos.zobrist(), Transposition::exact(d, sc, m)); - let ctrl = Control::Unlimited; - assert_eq!(e.nw(&pos, b, d, p, &ctrl), Ok(Pv::new(sc, Some(m)))); + assert_eq!( + e.nw(&pos, b, d, p, &Control::Unlimited), + Ok(Pv::new(sc, Some(m))) + ); } #[proptest] fn nw_finds_score_bound( - e: Engine, + #[by_ref] e: Engine, pos: Evaluator, #[filter((Value::lower()..Value::upper()).contains(&#b))] b: Score, d: Depth, - #[filter(#p >= 0)] p: Ply, + #[filter(#p > 0)] p: Ply, ) { assert_eq!( e.nw(&pos, b, d, p, &Control::Unlimited)? < b, @@ -564,11 +552,11 @@ mod tests { #[proptest] fn pvs_returns_drawn_score_if_game_ends_in_a_draw( - e: Engine, + #[by_ref] e: Engine, #[filter(#pos.outcome().is_some_and(|o| o.is_draw()))] pos: Evaluator, #[filter(!#b.is_empty())] b: Range, d: Depth, - p: Ply, + #[filter(#p > 0 || #pos.outcome() != Some(Outcome::DrawByThreefoldRepetition))] p: Ply, ) { assert_eq!( e.pvs(&pos, b, d, p, &Control::Unlimited), @@ -590,39 +578,18 @@ mod tests { ); } - #[proptest] - fn fw_finds_best_score(e: Engine, pos: Evaluator, d: Depth, #[filter(#p >= 0)] p: Ply) { - assert_eq!(e.fw(&pos, d, p, &Control::Unlimited)?, negamax(&pos, d, p)); - } - - #[proptest] - fn fw_does_not_depend_on_configuration( - x: Options, - y: Options, - pos: Evaluator, - d: Depth, - #[filter(#p >= 0)] p: Ply, - ) { - let x = Engine::with_options(&x); - let y = Engine::with_options(&y); - - let ctrl = Control::Unlimited; - - assert_eq!( - x.fw(&pos, d, p, &ctrl)?.score(), - y.fw(&pos, d, p, &ctrl)?.score() - ); - } - #[proptest] fn search_finds_the_principal_variation( mut e: Engine, pos: Evaluator, #[filter(#d > 1)] d: Depth, ) { + let interrupter = Trigger::armed(); + let time = Duration::MAX..Duration::MAX; + assert_eq!( - e.search(&pos, &Limits::Depth(d), &Trigger::armed()).score(), - e.fw(&pos, d, Ply::new(0), &Control::Unlimited)?.score() + e.search(&pos, &Limits::Depth(d), &interrupter).score(), + e.aw(&pos, d, u64::MAX, &time, &interrupter).score() ); } From 65c78f59370d508f8abf00b1c83d590aa46764f3 Mon Sep 17 00:00:00 2001 From: Bruno Dutra Date: Thu, 21 Nov 2024 00:17:15 +0100 Subject: [PATCH 3/3] fix latest clippy warnings --- Cargo.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index b6bed774..986c9af1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,10 @@ readme = "README.md" keywords = ["chess"] [lints.rust] -unexpected_cfgs = { level = "warn", check-cfg = ['cfg(coverage)'] } +unexpected_cfgs = { level = "warn", check-cfg = [ + 'cfg(coverage)', + 'cfg(feature, values("used_linker"))', +] } [dependencies] arrayvec = { version = "0.7.6", default-features = false, features = ["std"] }