From 6e0566a1a251ebdf4001df60fc77300693227ac6 Mon Sep 17 00:00:00 2001 From: Thomas Hahn Date: Sun, 29 Sep 2024 12:09:11 -0400 Subject: [PATCH] Restrict type of MCSignType to double or complex --- c++/triqs/mc_tools/mc_concepts.hpp | 8 ++++++++ c++/triqs/mc_tools/mc_generic.cpp | 27 ++++++++++++++------------- c++/triqs/mc_tools/mc_generic.hpp | 5 +++-- c++/triqs/mc_tools/mc_measure.cpp | 11 ++++++----- c++/triqs/mc_tools/mc_measure.hpp | 6 +++--- c++/triqs/mc_tools/mc_measure_set.cpp | 9 +++++---- c++/triqs/mc_tools/mc_measure_set.hpp | 5 +++-- c++/triqs/mc_tools/mc_move.cpp | 7 ++++--- c++/triqs/mc_tools/mc_move.hpp | 6 +++--- c++/triqs/mc_tools/mc_move_set.cpp | 17 +++++++++-------- c++/triqs/mc_tools/mc_move_set.hpp | 5 +++-- 11 files changed, 61 insertions(+), 45 deletions(-) diff --git a/c++/triqs/mc_tools/mc_concepts.hpp b/c++/triqs/mc_tools/mc_concepts.hpp index 452f1e71e..57d0fd048 100644 --- a/c++/triqs/mc_tools/mc_concepts.hpp +++ b/c++/triqs/mc_tools/mc_concepts.hpp @@ -21,6 +21,7 @@ #include
#include +#include #include #include @@ -28,6 +29,13 @@ namespace triqs::mc_tools { + /** + * @brief Check if a type is either a double or a std::complex. + * @tparam T Type to check. + */ + template + concept DoubleOrComplex = std::same_as || std::same_as>; + /** * @brief Check if a type can be used as a MC move. * @details It checks if the given type has `attempt()` and `accept()` methods that return a value which is diff --git a/c++/triqs/mc_tools/mc_generic.cpp b/c++/triqs/mc_tools/mc_generic.cpp index 00c66b9cc..6af5cf076 100644 --- a/c++/triqs/mc_tools/mc_generic.cpp +++ b/c++/triqs/mc_tools/mc_generic.cpp @@ -18,6 +18,7 @@ // // Authors: Michel Ferrero, Henri Menke, Olivier Parcollet, Priyanka Seth, Hugo U. R. Strand, Nils Wentzell, Thomas Ayral +#include "./mc_concepts.hpp" #include "./mc_generic.hpp" #include "../utility/signal_handler.hpp" #include "../utility/timestamp.hpp" @@ -39,7 +40,7 @@ namespace triqs::mc_tools { - template int mc_generic::run(run_param_t const ¶ms) { + template int mc_generic::run(run_param_t const ¶ms) { EXPECTS(params.cycle_length > 0); EXPECTS(params.stop_callback); EXPECTS(params.after_cycle_duty); @@ -167,7 +168,7 @@ namespace triqs::mc_tools { return status; } - template int mc_generic::warmup(run_param_t const ¶ms) { + template int mc_generic::warmup(run_param_t const ¶ms) { report_(3) << fmt::format("[Rank {}] Performing warum up phase...\n", params.comm.rank()); auto p = params; p.enable_measures = false; @@ -176,7 +177,7 @@ namespace triqs::mc_tools { return status; } - template int mc_generic::accumulate(run_param_t const ¶ms) { + template int mc_generic::accumulate(run_param_t const ¶ms) { report_(3) << fmt::format("[Rank {}] Performing accumulation up phase...\n", params.comm.rank()); auto p = params; p.enable_measures = true; @@ -186,7 +187,7 @@ namespace triqs::mc_tools { return status; } - template + template int mc_generic::run(std::int64_t ncycles, std::int64_t cycle_length, std::function stop_callback, bool enable_measures, mpi::communicator c, bool enable_calibration) { return run({.ncycles = ncycles, @@ -197,23 +198,23 @@ namespace triqs::mc_tools { .enable_calibration = enable_calibration}); } - template + template int mc_generic::warmup(std::int64_t ncycles, std::int64_t cycle_length, std::function stop_callback, MCSignType initial_sign, mpi::communicator c) { return warmup({.ncycles = ncycles, .cycle_length = cycle_length, .stop_callback = stop_callback, .initial_sign = initial_sign, .comm = c}); } - template + template int mc_generic::warmup(std::int64_t ncycles, std::int64_t cycle_length, std::function stop_callback, mpi::communicator c) { return warmup({.ncycles = ncycles, .cycle_length = cycle_length, .stop_callback = stop_callback, .comm = c}); } - template + template int mc_generic::accumulate(std::int64_t ncycles, std::int64_t cycle_length, std::function stop_callback, mpi::communicator c) { return accumulate({.ncycles = ncycles, .cycle_length = cycle_length, .stop_callback = stop_callback, .comm = c}); } - template + template int mc_generic::warmup_and_accumulate(std::int64_t ncycles_warmup, std::int64_t ncycles_acc, std::int64_t cycle_length, std::function stop_callback, MCSignType initial_sign, mpi::communicator c) { auto status = @@ -222,13 +223,13 @@ namespace triqs::mc_tools { return status; } - template + template int mc_generic::warmup_and_accumulate(std::int64_t ncycles_warmup, std::int64_t ncycles_acc, std::int64_t cycle_length, std::function stop_callback, mpi::communicator c) { return warmup_and_accumulate(ncycles_warmup, ncycles_acc, cycle_length, stop_callback, default_initial_sign, c); } - template void mc_generic::collect_results(mpi::communicator const &c) { + template void mc_generic::collect_results(mpi::communicator const &c) { report_(3) << fmt::format("[Rank {}] Collect results: Waiting for all MPI processes to finish accumulating...\n", c.rank()); // collect results from all MPI processes @@ -258,7 +259,7 @@ namespace triqs::mc_tools { } } - template void mc_generic::print_sim_info(run_param_t const ¶ms, std::int64_t cycle_counter) { + template void mc_generic::print_sim_info(run_param_t const ¶ms, std::int64_t cycle_counter) { // current simulation parameters auto const rank = params.comm.rank(); double const runtime = run_timer_; @@ -275,7 +276,7 @@ namespace triqs::mc_tools { if (params.enable_measures) report_(3) << measures_.report(); } - template void mc_generic::metropolis_step() { + template void mc_generic::metropolis_step() { double r = moves_.attempt(); if (rng_() < std::min(1.0, r)) { sign_ *= moves_.accept(); @@ -284,7 +285,7 @@ namespace triqs::mc_tools { } } - template void mc_generic::after_cycle_duties(run_param_t const ¶ms) { + template void mc_generic::after_cycle_duties(run_param_t const ¶ms) { params.after_cycle_duty(); if (params.enable_calibration) moves_.calibrate(params.comm); if (params.enable_measures) { diff --git a/c++/triqs/mc_tools/mc_generic.hpp b/c++/triqs/mc_tools/mc_generic.hpp index c7a9dfa82..4821804e7 100644 --- a/c++/triqs/mc_tools/mc_generic.hpp +++ b/c++/triqs/mc_tools/mc_generic.hpp @@ -20,6 +20,7 @@ #pragma once +#include "./mc_concepts.hpp" #include "./mc_measure_aux_set.hpp" #include "./mc_measure_set.hpp" #include "./mc_move_set.hpp" @@ -73,9 +74,9 @@ namespace triqs::mc_tools { * - A signal is caught by the triqs::utility::signal_handler. * - An exception is caught. * - * @tparam MCSignType Type of the sign/weight of a MC configuration. + * @tparam MCSignType triqs::mc_tools::DoubleOrComplex type of the sign/weight of a MC configuration. */ - template class mc_generic { + template class mc_generic { private: // Value to indicate that the initial sign of the user-provided parameters should not be used. static constexpr MCSignType default_initial_sign = std::numeric_limits::infinity(); diff --git a/c++/triqs/mc_tools/mc_measure.cpp b/c++/triqs/mc_tools/mc_measure.cpp index 004e7026f..ecd71e476 100644 --- a/c++/triqs/mc_tools/mc_measure.cpp +++ b/c++/triqs/mc_tools/mc_measure.cpp @@ -17,6 +17,7 @@ // // Authors: Michel Ferrero, Olivier Parcollet, Nils Wentzell +#include "./mc_concepts.hpp" #include "./mc_measure.hpp" #include @@ -27,18 +28,18 @@ namespace triqs::mc_tools { - template void measure::collect_results(mpi::communicator const &c) { + template void measure::collect_results(mpi::communicator const &c) { if (enable_timer_) timer_.start(); ptr_->collect_results(c); if (enable_timer_) timer_.stop(); } - template std::string measure::report() const { + template std::string measure::report() const { if (enable_report_) return ptr_->report(); return {}; } - template std::string measure::get_timings(std::string const &name, std::string const &prefix) const { + template std::string measure::get_timings(std::string const &name, std::string const &prefix) const { if (is_measure_set_) { auto str = fmt::format("{}Measure set {}: Duration = {:.4f}\n", prefix, name, duration()); return str + ptr_->ms_get_timings(prefix + " "); @@ -47,12 +48,12 @@ namespace triqs::mc_tools { } } - template std::vector measure::names() const { + template std::vector measure::names() const { if (is_measure_set_) return ptr_->ms_names(); return {}; } - template double measure::duration() const { + template double measure::duration() const { if (enable_timer_) return static_cast(timer_); return 0.0; } diff --git a/c++/triqs/mc_tools/mc_measure.hpp b/c++/triqs/mc_tools/mc_measure.hpp index dcaf4ff8a..ce77a7f39 100644 --- a/c++/triqs/mc_tools/mc_measure.hpp +++ b/c++/triqs/mc_tools/mc_measure.hpp @@ -37,7 +37,7 @@ namespace triqs::mc_tools { // Forward declaration. - template class measure_set; + template class measure_set; /** * @brief Type erasure class for MC measures. @@ -56,9 +56,9 @@ namespace triqs::mc_tools { * - `void h5_write(h5::group, std::string const &, T const &) const`: Writes the measure object of type `T` to HDF5. * - `void h5_read(h5::group, std::string const &, T &)`: Reads the measure object of type `T` from HDF5. * - * @tparam MCSignType Type of the sign/weight of a MC configuration. + * @tparam MCSignType triqs::mc_tools::DoubleOrComplex type of the sign/weight of a MC configuration. */ - template class measure { + template class measure { private: // MC measure concept defines the interface for MC measures. struct measure_concept { diff --git a/c++/triqs/mc_tools/mc_measure_set.cpp b/c++/triqs/mc_tools/mc_measure_set.cpp index 9617dfd3e..11a4fcf2a 100644 --- a/c++/triqs/mc_tools/mc_measure_set.cpp +++ b/c++/triqs/mc_tools/mc_measure_set.cpp @@ -17,6 +17,7 @@ // // Authors: Michel Ferrero, Olivier Parcollet, Nils Wentzell +#include "./mc_concepts.hpp" #include "./mc_measure_set.hpp" #include @@ -28,7 +29,7 @@ namespace triqs::mc_tools { - template std::vector measure_set::names() const { + template std::vector measure_set::names() const { std::vector res; for (auto &[name, m] : measures_) { res.push_back(name); @@ -38,11 +39,11 @@ namespace triqs::mc_tools { return res; } - template void measure_set::collect_results(const mpi::communicator &c) { + template void measure_set::collect_results(const mpi::communicator &c) { for (auto &[name, m] : measures_) m.collect_results(c); } - template std::string measure_set::report() const { + template std::string measure_set::report() const { std::string res; for (auto &[name, m] : measures_) { auto str = m.report(); @@ -51,7 +52,7 @@ namespace triqs::mc_tools { return res; } - template std::string measure_set::get_timings(std::string const &prefix) const { + template std::string measure_set::get_timings(std::string const &prefix) const { std::string res; for (auto const &[name, m] : measures_) { res += m.get_timings(name, prefix); } return res; diff --git a/c++/triqs/mc_tools/mc_measure_set.hpp b/c++/triqs/mc_tools/mc_measure_set.hpp index 6139b7053..bc219d77b 100644 --- a/c++/triqs/mc_tools/mc_measure_set.hpp +++ b/c++/triqs/mc_tools/mc_measure_set.hpp @@ -19,6 +19,7 @@ #pragma once +#include "./mc_concepts.hpp" #include "./mc_measure.hpp" #include @@ -46,9 +47,9 @@ namespace triqs::mc_tools { * - measure_set::collect_results: Calls the measure::collect_results method for all registered MC measures. * - measure_set::report: Concatenates the reports from all measurements by calling their measure::report method. * - * @tparam MCSignType Type of the sign/weight of a MC configuration. + * @tparam MCSignType triqs::mc_tools::DoubleOrComplex type of the sign/weight of a MC configuration. */ - template class measure_set { + template class measure_set { public: /// Map type used for storing the measures. using measure_map_t = std::map>; diff --git a/c++/triqs/mc_tools/mc_move.cpp b/c++/triqs/mc_tools/mc_move.cpp index 874f7808a..37932de5d 100644 --- a/c++/triqs/mc_tools/mc_move.cpp +++ b/c++/triqs/mc_tools/mc_move.cpp @@ -17,6 +17,7 @@ // // Authors: Michel Ferrero, Olivier Parcollet, Nils Wentzell +#include "./mc_concepts.hpp" #include "./mc_move.hpp" #include @@ -28,21 +29,21 @@ namespace triqs::mc_tools { - template void move::collect_statistics(mpi::communicator const &c) { + template void move::collect_statistics(mpi::communicator const &c) { std::uint64_t nacc_tot = mpi::all_reduce(nacc_, c); std::uint64_t nprop_tot = mpi::all_reduce(nprop_, c); acc_rate_ = nacc_tot / static_cast(nprop_tot); ptr_->collect_statistics(c); } - template void move::clear_statistics() { + template void move::clear_statistics() { nacc_ = 0; nprop_ = 0; acc_rate_ = -1; ptr_->ms_clear_statistics(); } - template std::string move::get_statistics(std::string const &name, std::string const &prefix) const { + template std::string move::get_statistics(std::string const &name, std::string const &prefix) const { if (is_move_set_) { auto str = fmt::format("{}Move set {}: Proposed = {}, Accepted = {}, Rate = {:.4f}\n", prefix, name, nprop_, nacc_, acc_rate_); return str + ptr_->ms_get_statistics(prefix + " "); diff --git a/c++/triqs/mc_tools/mc_move.hpp b/c++/triqs/mc_tools/mc_move.hpp index 2600a9f82..3c2092475 100644 --- a/c++/triqs/mc_tools/mc_move.hpp +++ b/c++/triqs/mc_tools/mc_move.hpp @@ -35,7 +35,7 @@ namespace triqs::mc_tools { // Forward declaration. - template class move_set; + template class move_set; /** * @brief Type erasure class for MC moves. @@ -59,9 +59,9 @@ namespace triqs::mc_tools { * - `void h5_write(h5::group, std::string const &, T const &) const`: Writes the move object of type `T` to HDF5. * - `void h5_read(h5::group, std::string const &, T &)`: Reads the move object of type `T` from HDF5. * - * @tparam MCSignType Type of the sign/weight of a MC configuration. + * @tparam MCSignType triqs::mc_tools::DoubleOrComplex type of the sign/weight of a MC configuration. */ - template class move { + template class move { private: // MC move concept defines the interface for MC moves. struct move_concept { diff --git a/c++/triqs/mc_tools/mc_move_set.cpp b/c++/triqs/mc_tools/mc_move_set.cpp index 8d5e823a9..72a064cc8 100644 --- a/c++/triqs/mc_tools/mc_move_set.cpp +++ b/c++/triqs/mc_tools/mc_move_set.cpp @@ -17,6 +17,7 @@ // // Authors: Michel Ferrero, Olivier Parcollet, Nils Wentzell +#include "./mc_concepts.hpp" #include "./mc_move_set.hpp" #include @@ -35,7 +36,7 @@ namespace triqs::mc_tools { - template double move_set::attempt() { + template double move_set::attempt() { assert(std::abs(acc_probs_.back() - 1.0) < 1e-13); // choose a move @@ -46,19 +47,19 @@ namespace triqs::mc_tools { return check_ratio(moves_[current_].attempt()); } - template void move_set::clear_statistics() { + template void move_set::clear_statistics() { for (auto &m : moves_) m.clear_statistics(); } - template void move_set::collect_statistics(mpi::communicator const &c) { + template void move_set::collect_statistics(mpi::communicator const &c) { for (auto &m : moves_) m.collect_statistics(c); } - template void move_set::calibrate(mpi::communicator const &c) { + template void move_set::calibrate(mpi::communicator const &c) { for (auto &m : moves_) m.calibrate(c); } - template std::map move_set::get_acceptance_rates() const { + template std::map move_set::get_acceptance_rates() const { std::map res; for (auto const &[m, name] : itertools::zip(moves_, names_)) { res.insert({name, m.acceptance_rate()}); @@ -68,13 +69,13 @@ namespace triqs::mc_tools { return res; } - template std::string move_set::get_statistics(std::string const &prefix) const { + template std::string move_set::get_statistics(std::string const &prefix) const { std::string str; for (auto const &[m, name] : itertools::zip(moves_, names_)) { str += m.get_statistics(name, prefix); } return str; } - template void move_set::initialize() { + template void move_set::initialize() { // initialize is called in add, so we need to resize the vectors probs_.resize(weights_.size()); acc_probs_.resize(weights_.size()); @@ -87,7 +88,7 @@ namespace triqs::mc_tools { std::partial_sum(probs_.begin(), probs_.end(), acc_probs_.begin()); } - template double move_set::check_ratio(MCSignType ratio) { + template double move_set::check_ratio(MCSignType ratio) { // handle infinities in case of double MC weights if constexpr (std::is_same_v) { if (std::isinf(ratio)) { diff --git a/c++/triqs/mc_tools/mc_move_set.hpp b/c++/triqs/mc_tools/mc_move_set.hpp index c808342e1..abaa244a4 100644 --- a/c++/triqs/mc_tools/mc_move_set.hpp +++ b/c++/triqs/mc_tools/mc_move_set.hpp @@ -19,6 +19,7 @@ #pragma once +#include "./mc_concepts.hpp" #include "./mc_move.hpp" #include "./random_generator.hpp" @@ -51,9 +52,9 @@ namespace triqs::mc_tools { * - move_set::calibrate, move_set::collect_statistics and the HDF5 routines loop over all moves and call the * corresponding method for each registered move. * - * @tparam MCSignType Type of the sign/weight of a MC configuration. + * @tparam MCSignType triqs::mc_tools::DoubleOrComplex type of the sign/weight of a MC configuration. */ - template class move_set { + template class move_set { public: /** * @brief Construct a move set with a given random number generator (stored in a `std::reference_wrapper`).