Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transposition table! #6

Merged
merged 11 commits into from
Jul 7, 2024
5 changes: 5 additions & 0 deletions src/chess.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ inline constexpr int max_ply = 200;
inline constexpr int max_game_ply = max_ply + 400;
inline constexpr int max_moves = 256;

inline constexpr score score_infinite = 32000;
inline constexpr score score_mate = 31500;
inline constexpr score score_win = score_mate - max_ply;
inline constexpr score score_none = 32001;

} // namespace constants

enum class file : u8 {
Expand Down
1 change: 1 addition & 0 deletions src/search/bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ void search::bench::run(searcher& searcher, const u32 depth) {
for (const auto& fen : bench_fens) {
board::position pos(fen);

searcher.set_start_time(utils::time::get_time_ms());
searcher.main_search(pos);

total_nodes += searcher.searched_nodes();
Expand Down
96 changes: 70 additions & 26 deletions src/search/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "../eval/eval.hpp"
#include "../moves/movegen.hpp"
#include "../utils/time.hpp"
#include "tt.hpp"

namespace search {

Expand All @@ -21,6 +22,8 @@ void searcher::set_limits(const u64 nodes_limit, const u64 time_limit, const u32
m_limits.depth_limit = depth_limit;
}

void searcher::set_start_time(const u64 time) { m_timer.set_start_time(time); }

void searcher::parse_time_control(const std::vector<std::string>& command, const color stm) {
u64 base_time{};
u16 increment{};
Expand All @@ -46,16 +49,14 @@ void searcher::parse_time_control(const std::vector<std::string>& command, const
m_timer = time_manager(utils::time::get_time_ms(), base_time, increment);
}

/// @brief Main entrypoint for the search function
void searcher::main_search(const board::position& pos) {
m_timer.set_start_time(utils::time::get_time_ms());
reset_info();
auto best_move = moves::move::null();

// Iterative deepening loop
for (int current_depth = 1; current_depth <= m_limits.depth_limit; ++current_depth) {
const score best_score =
negamax(pos, -score_infinite, score_infinite, current_depth, 0, m_info.pv);
const score best_score = negamax(pos, -constants::score_infinite, constants::score_infinite,
current_depth, 0, m_info.pv);

if (m_info.stopped) {
// If search stopped too early and we don't have a best move, we update it in order to
Expand All @@ -75,7 +76,6 @@ void searcher::main_search(const board::position& pos) {
std::cout << std::format("bestmove {}", best_move.to_string()) << std::endl;
}

/// @brief Quiescence search, to get rid of the horizon effect
score searcher::qsearch(const board::position& pos, score alpha, const score beta, const int ply) {
++m_info.searched_nodes;

Expand All @@ -87,6 +87,14 @@ score searcher::qsearch(const board::position& pos, score alpha, const score bet
return 0;
}

tt::tt_entry entry;

const bool tt_hit = tt::global_tt.probe(pos.key(), entry);
const auto tt_score = tt_hit ? tt::score_from_tt(entry.value(), ply) : constants::score_none;

if (tt_score != constants::score_none && entry.can_use_score(alpha, beta))
return tt_score;

const score static_eval = eval::evaluate(pos);

if (ply >= constants::max_ply)
Expand All @@ -98,7 +106,9 @@ score searcher::qsearch(const board::position& pos, score alpha, const score bet
if (static_eval > alpha)
alpha = static_eval;

score best_score = static_eval;
score best_score = static_eval;
auto best_move = moves::move::null();

moves::move_list move_list;
generate_all_captures(pos, move_list);
move_list.score_moves(pos);
Expand All @@ -118,19 +128,28 @@ score searcher::qsearch(const board::position& pos, score alpha, const score bet
if (m_info.stopped)
return 0;

if (current_score > best_score)
if (current_score > best_score) {
best_score = current_score;

if (current_score > alpha)
alpha = current_score;
if (current_score > alpha) {
alpha = current_score;
best_move = current_move;

if (alpha >= beta)
break;
if (alpha >= beta)
break;
}
}
}

const auto tt_flag = best_score >= beta ? tt::tt_entry::tt_flag::lower_bound
: tt::tt_entry::tt_flag::upper_bound;

tt::global_tt.store(pos.key(), tt::tt_entry(pos.key(), best_move,
tt::score_to_tt(best_score, ply), 0, tt_flag));

return best_score;
}

/// @brief Fail-soft negamax algorithm with alpha-beta pruning
score searcher::negamax(const board::position& pos,
score alpha,
const score beta,
Expand All @@ -151,15 +170,31 @@ score searcher::negamax(const board::position& pos,
if (depth <= 0)
return qsearch(pos, alpha, beta, ply);

if (ply > 0 && pos.has_repeated())
const bool root_node = ply == 0;

if (!root_node && pos.has_repeated())
return 0;

if (ply >= constants::max_ply)
return eval::evaluate(pos);

u16 legal_moves{};
pv_line child_pv{};
score best_score = -score_infinite;
tt::tt_entry entry;

const bool tt_hit = tt::global_tt.probe(pos.key(), entry);
const auto tt_score = tt_hit ? tt::score_from_tt(entry.value(), ply) : constants::score_none;
const u8 tt_depth = entry.depth();

if (!root_node && tt_score != constants::score_none && tt_depth >= depth
&& entry.can_use_score(alpha, beta))
return tt_score;

u16 legal_moves{};
pv_line child_pv{};

score best_score = -constants::score_infinite;
auto best_move = moves::move::null();
const score original_alpha = alpha;

moves::move_list move_list;
generate_all_moves(pos, move_list);
move_list.score_moves(pos);
Expand All @@ -178,27 +213,36 @@ score searcher::negamax(const board::position& pos,

const score current_score = -negamax(copy, -beta, -alpha, depth - 1, ply + 1, child_pv);

if (current_score > best_score)
if (current_score > best_score) {
best_score = current_score;

if (current_score > alpha) {
alpha = current_score;
pv.update(current_move, child_pv);
if (current_score > alpha) {
alpha = current_score;
best_move = current_move;
pv.update(current_move, child_pv);

if (alpha >= beta)
break;
}
}

// Double-check if search stopped to make sure we don't exceed the search limits
if (m_info.stopped)
return 0;

if (alpha >= beta)
break;
}

// Checkmate / stalemate detection
if (!legal_moves) {
return pos.checkers().bit_count() > 0 ? -score_mate + ply : 0;
return pos.checkers().bit_count() > 0 ? -constants::score_mate + ply : 0;
}

const auto tt_flag = best_score <= original_alpha ? tt::tt_entry::tt_flag::upper_bound
: best_score >= beta ? tt::tt_entry::tt_flag::lower_bound
: tt::tt_entry::tt_flag::exact;

tt::global_tt.store(pos.key(), tt::tt_entry(pos.key(), best_move,
tt::score_to_tt(best_score, ply), depth, tt_flag));

return best_score;
}

Expand All @@ -213,8 +257,8 @@ bool searcher::should_stop() const {
}

void searcher::report_info(u64 elapsed, int depth, score score, const pv_line& pv) const {
std::cout << std::format("info depth {} score cp {} time {} nodes {} nps {} pv{}", depth,
score, elapsed, m_info.searched_nodes,
std::cout << std::format("info depth {} score cp {} time {} nodes {} nps {} pv{}", depth, score,
elapsed, m_info.searched_nodes,
m_info.searched_nodes / std::max<u64>(1, elapsed) * 1000,
pv.to_string())
<< std::endl;
Expand Down
31 changes: 28 additions & 3 deletions src/search/search.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@

namespace search {

constexpr score score_infinite = 32000;
constexpr score score_mate = 31500;

struct pv_line {
std::array<moves::move, constants::max_moves> moves{};
usize length{};
Expand Down Expand Up @@ -59,20 +56,48 @@ class searcher {

void reset_info();
void set_limits(u64 nodes_limit, u64 time_limit, u32 depth_limit);
void set_start_time(u64 time);
void parse_time_control(const std::vector<std::string>& command, color stm);

/// @brief Main entrypoint for the search function
/// @param pos Position to search from
void main_search(const board::position& pos);

private:
search_info m_info{};
search_limits m_limits{};
time_manager m_timer{};

/// @brief Quiescence search, to get rid of the horizon effect
/// @param pos Position to search from
/// @param alpha Best score for the maximizing player
/// @param beta Best score for the minimizing player
/// @param ply Internal depth of the search tree (seldepth)
/// @returns The best score found
/// @note See https://en.wikipedia.org/wiki/Quiescence_search for reference
score qsearch(const board::position& pos, score alpha, score beta, int ply);

/// @brief Fail-soft negamax algorithm with alpha-beta pruning
/// @param pos Position to search from
/// @param alpha Best score for the maximizing player
/// @param beta Best score for the minimizing player
/// @param depth Depth to start searching from
/// @param ply Internal depth of the search tree (seldepth)
/// @param pv PV-List on the stack, to keep track of the principal variation
/// @returns The best score found
/// @note See https://en.wikipedia.org/wiki/Negamax for reference
score negamax(
const board::position& pos, score alpha, score beta, int depth, int ply, pv_line& pv);

/// @brief Determines if search should stop according to the search limits
/// @returns true if time is up or the nodes or time limit is exceeded
[[nodiscard]] bool should_stop() const;

/// @brief Reports uci-compliant info about the search tree
/// @param elapsed Elapsed time since the start of the search, in milliseconds
/// @param depth Depth of the search tree
/// @param score Score of the current position
/// @param pv Principal variation line from the current position
void report_info(u64 elapsed, int depth, score score, const pv_line& pv) const;
};

Expand Down
49 changes: 48 additions & 1 deletion src/search/tt.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,50 @@
#include "tt.hpp"

namespace search::tt {}
#include <algorithm>

namespace search::tt {

bool transposition_table::probe(const zobrist_key key, tt_entry& entry) const {
if (const auto current_entry = m_data[index(key)]; current_entry.key_matches(key)) {
entry = current_entry;

return true;
}

return false;
}

void transposition_table::clear() { std::ranges::fill(m_data, tt_entry{}); }

void transposition_table::resize(const usize size_mb) {
constexpr usize bytes_per_mb = 1024 * 1024;
const usize entry_count = (size_mb * bytes_per_mb) / sizeof(tt_entry);

m_data.resize(entry_count);
clear();
}

void transposition_table::prefetch(const zobrist_key key) {
__builtin_prefetch(&m_data[index(key)]);
}

void transposition_table::store(const zobrist_key key, const tt_entry& entry) {
m_data[index(key)] = entry;
}

u64 transposition_table::index(const zobrist_key key) const {
return (static_cast<u128>(key) * static_cast<u128>(m_data.size())) >> 64;
}

u16 transposition_table::hashfull() const {
u16 hashfull{};

for (int i = 0; i < 1000; ++i) {
if (m_data[i].flag() != tt_entry::tt_flag::none)
++hashfull;
}

return hashfull;
}

} // namespace search::tt
Loading