Skip to content

Commit

Permalink
Restrict type of MCSignType to double or complex
Browse files Browse the repository at this point in the history
  • Loading branch information
Thoemi09 committed Oct 2, 2024
1 parent 02420cc commit 6e0566a
Show file tree
Hide file tree
Showing 11 changed files with 61 additions and 45 deletions.
8 changes: 8 additions & 0 deletions c++/triqs/mc_tools/mc_concepts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,21 @@

#include <h5/group.hpp>
#include <mpi/communicator.hpp>
#include <nda/concepts.hpp>

#include <concepts>
#include <string>
#include <utility>

namespace triqs::mc_tools {

/**
* @brief Check if a type is either a double or a std::complex<double>.
* @tparam T Type to check.
*/
template <typename T>
concept DoubleOrComplex = std::same_as<T, double> || std::same_as<T, std::complex<double>>;

/**
* @brief Check if a type can be used as a MC move.
* @details It checks if the given type has `attempt()` and `accept()` methods that return a value which is
Expand Down
27 changes: 14 additions & 13 deletions c++/triqs/mc_tools/mc_generic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
//
// Authors: Michel Ferrero, Henri Menke, Olivier Parcollet, Priyanka Seth, Hugo U. R. Strand, Nils Wentzell, Thomas Ayral

#include "./mc_concepts.hpp"
#include "./mc_generic.hpp"
#include "../utility/signal_handler.hpp"
#include "../utility/timestamp.hpp"
Expand All @@ -39,7 +40,7 @@

namespace triqs::mc_tools {

template <typename MCSignType> int mc_generic<MCSignType>::run(run_param_t const &params) {
template <DoubleOrComplex MCSignType> int mc_generic<MCSignType>::run(run_param_t const &params) {
EXPECTS(params.cycle_length > 0);
EXPECTS(params.stop_callback);
EXPECTS(params.after_cycle_duty);
Expand Down Expand Up @@ -167,7 +168,7 @@ namespace triqs::mc_tools {
return status;
}

template <typename MCSignType> int mc_generic<MCSignType>::warmup(run_param_t const &params) {
template <DoubleOrComplex MCSignType> int mc_generic<MCSignType>::warmup(run_param_t const &params) {
report_(3) << fmt::format("[Rank {}] Performing warum up phase...\n", params.comm.rank());
auto p = params;
p.enable_measures = false;
Expand All @@ -176,7 +177,7 @@ namespace triqs::mc_tools {
return status;
}

template <typename MCSignType> int mc_generic<MCSignType>::accumulate(run_param_t const &params) {
template <DoubleOrComplex MCSignType> int mc_generic<MCSignType>::accumulate(run_param_t const &params) {
report_(3) << fmt::format("[Rank {}] Performing accumulation up phase...\n", params.comm.rank());
auto p = params;
p.enable_measures = true;
Expand All @@ -186,7 +187,7 @@ namespace triqs::mc_tools {
return status;
}

template <typename MCSignType>
template <DoubleOrComplex MCSignType>
int mc_generic<MCSignType>::run(std::int64_t ncycles, std::int64_t cycle_length, std::function<bool()> stop_callback, bool enable_measures,
mpi::communicator c, bool enable_calibration) {
return run({.ncycles = ncycles,
Expand All @@ -197,23 +198,23 @@ namespace triqs::mc_tools {
.enable_calibration = enable_calibration});
}

template <typename MCSignType>
template <DoubleOrComplex MCSignType>
int mc_generic<MCSignType>::warmup(std::int64_t ncycles, std::int64_t cycle_length, std::function<bool()> stop_callback, MCSignType initial_sign,
mpi::communicator c) {
return warmup({.ncycles = ncycles, .cycle_length = cycle_length, .stop_callback = stop_callback, .initial_sign = initial_sign, .comm = c});
}

template <typename MCSignType>
template <DoubleOrComplex MCSignType>
int mc_generic<MCSignType>::warmup(std::int64_t ncycles, std::int64_t cycle_length, std::function<bool()> stop_callback, mpi::communicator c) {
return warmup({.ncycles = ncycles, .cycle_length = cycle_length, .stop_callback = stop_callback, .comm = c});
}

template <typename MCSignType>
template <DoubleOrComplex MCSignType>
int mc_generic<MCSignType>::accumulate(std::int64_t ncycles, std::int64_t cycle_length, std::function<bool()> stop_callback, mpi::communicator c) {
return accumulate({.ncycles = ncycles, .cycle_length = cycle_length, .stop_callback = stop_callback, .comm = c});
}

template <typename MCSignType>
template <DoubleOrComplex MCSignType>
int mc_generic<MCSignType>::warmup_and_accumulate(std::int64_t ncycles_warmup, std::int64_t ncycles_acc, std::int64_t cycle_length,
std::function<bool()> stop_callback, MCSignType initial_sign, mpi::communicator c) {
auto status =
Expand All @@ -222,13 +223,13 @@ namespace triqs::mc_tools {
return status;
}

template <typename MCSignType>
template <DoubleOrComplex MCSignType>
int mc_generic<MCSignType>::warmup_and_accumulate(std::int64_t ncycles_warmup, std::int64_t ncycles_acc, std::int64_t cycle_length,
std::function<bool()> stop_callback, mpi::communicator c) {
return warmup_and_accumulate(ncycles_warmup, ncycles_acc, cycle_length, stop_callback, default_initial_sign, c);
}

template <typename MCSignType> void mc_generic<MCSignType>::collect_results(mpi::communicator const &c) {
template <DoubleOrComplex MCSignType> void mc_generic<MCSignType>::collect_results(mpi::communicator const &c) {
report_(3) << fmt::format("[Rank {}] Collect results: Waiting for all MPI processes to finish accumulating...\n", c.rank());

// collect results from all MPI processes
Expand Down Expand Up @@ -258,7 +259,7 @@ namespace triqs::mc_tools {
}
}

template <typename MCSignType> void mc_generic<MCSignType>::print_sim_info(run_param_t const &params, std::int64_t cycle_counter) {
template <DoubleOrComplex MCSignType> void mc_generic<MCSignType>::print_sim_info(run_param_t const &params, std::int64_t cycle_counter) {
// current simulation parameters
auto const rank = params.comm.rank();
double const runtime = run_timer_;
Expand All @@ -275,7 +276,7 @@ namespace triqs::mc_tools {
if (params.enable_measures) report_(3) << measures_.report();
}

template <typename MCSignType> void mc_generic<MCSignType>::metropolis_step() {
template <DoubleOrComplex MCSignType> void mc_generic<MCSignType>::metropolis_step() {
double r = moves_.attempt();
if (rng_() < std::min(1.0, r)) {
sign_ *= moves_.accept();
Expand All @@ -284,7 +285,7 @@ namespace triqs::mc_tools {
}
}

template <typename MCSignType> void mc_generic<MCSignType>::after_cycle_duties(run_param_t const &params) {
template <DoubleOrComplex MCSignType> void mc_generic<MCSignType>::after_cycle_duties(run_param_t const &params) {
params.after_cycle_duty();
if (params.enable_calibration) moves_.calibrate(params.comm);
if (params.enable_measures) {
Expand Down
5 changes: 3 additions & 2 deletions c++/triqs/mc_tools/mc_generic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#pragma once

#include "./mc_concepts.hpp"
#include "./mc_measure_aux_set.hpp"
#include "./mc_measure_set.hpp"
#include "./mc_move_set.hpp"
Expand Down Expand Up @@ -73,9 +74,9 @@ namespace triqs::mc_tools {
* - A signal is caught by the triqs::utility::signal_handler.
* - An exception is caught.
*
* @tparam MCSignType Type of the sign/weight of a MC configuration.
* @tparam MCSignType triqs::mc_tools::DoubleOrComplex type of the sign/weight of a MC configuration.
*/
template <typename MCSignType> class mc_generic {
template <DoubleOrComplex MCSignType> class mc_generic {
private:
// Value to indicate that the initial sign of the user-provided parameters should not be used.
static constexpr MCSignType default_initial_sign = std::numeric_limits<double>::infinity();
Expand Down
11 changes: 6 additions & 5 deletions c++/triqs/mc_tools/mc_measure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
//
// Authors: Michel Ferrero, Olivier Parcollet, Nils Wentzell

#include "./mc_concepts.hpp"
#include "./mc_measure.hpp"

#include <fmt/format.h>
Expand All @@ -27,18 +28,18 @@

namespace triqs::mc_tools {

template <typename MCSignType> void measure<MCSignType>::collect_results(mpi::communicator const &c) {
template <DoubleOrComplex MCSignType> void measure<MCSignType>::collect_results(mpi::communicator const &c) {
if (enable_timer_) timer_.start();
ptr_->collect_results(c);
if (enable_timer_) timer_.stop();
}

template <typename MCSignType> std::string measure<MCSignType>::report() const {
template <DoubleOrComplex MCSignType> std::string measure<MCSignType>::report() const {
if (enable_report_) return ptr_->report();
return {};
}

template <typename MCSignType> std::string measure<MCSignType>::get_timings(std::string const &name, std::string const &prefix) const {
template <DoubleOrComplex MCSignType> std::string measure<MCSignType>::get_timings(std::string const &name, std::string const &prefix) const {
if (is_measure_set_) {
auto str = fmt::format("{}Measure set {}: Duration = {:.4f}\n", prefix, name, duration());
return str + ptr_->ms_get_timings(prefix + " ");
Expand All @@ -47,12 +48,12 @@ namespace triqs::mc_tools {
}
}

template <typename MCSignType> std::vector<std::string> measure<MCSignType>::names() const {
template <DoubleOrComplex MCSignType> std::vector<std::string> measure<MCSignType>::names() const {
if (is_measure_set_) return ptr_->ms_names();
return {};
}

template <typename MCSignType> double measure<MCSignType>::duration() const {
template <DoubleOrComplex MCSignType> double measure<MCSignType>::duration() const {
if (enable_timer_) return static_cast<double>(timer_);
return 0.0;
}
Expand Down
6 changes: 3 additions & 3 deletions c++/triqs/mc_tools/mc_measure.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
namespace triqs::mc_tools {

// Forward declaration.
template <typename MCSignType> class measure_set;
template <DoubleOrComplex MCSignType> class measure_set;

/**
* @brief Type erasure class for MC measures.
Expand All @@ -56,9 +56,9 @@ namespace triqs::mc_tools {
* - `void h5_write(h5::group, std::string const &, T const &) const`: Writes the measure object of type `T` to HDF5.
* - `void h5_read(h5::group, std::string const &, T &)`: Reads the measure object of type `T` from HDF5.
*
* @tparam MCSignType Type of the sign/weight of a MC configuration.
* @tparam MCSignType triqs::mc_tools::DoubleOrComplex type of the sign/weight of a MC configuration.
*/
template <typename MCSignType> class measure {
template <DoubleOrComplex MCSignType> class measure {
private:
// MC measure concept defines the interface for MC measures.
struct measure_concept {
Expand Down
9 changes: 5 additions & 4 deletions c++/triqs/mc_tools/mc_measure_set.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
//
// Authors: Michel Ferrero, Olivier Parcollet, Nils Wentzell

#include "./mc_concepts.hpp"
#include "./mc_measure_set.hpp"

#include <fmt/format.h>
Expand All @@ -28,7 +29,7 @@

namespace triqs::mc_tools {

template <typename MCSignType> std::vector<std::string> measure_set<MCSignType>::names() const {
template <DoubleOrComplex MCSignType> std::vector<std::string> measure_set<MCSignType>::names() const {
std::vector<std::string> res;
for (auto &[name, m] : measures_) {
res.push_back(name);
Expand All @@ -38,11 +39,11 @@ namespace triqs::mc_tools {
return res;
}

template <typename MCSignType> void measure_set<MCSignType>::collect_results(const mpi::communicator &c) {
template <DoubleOrComplex MCSignType> void measure_set<MCSignType>::collect_results(const mpi::communicator &c) {
for (auto &[name, m] : measures_) m.collect_results(c);
}

template <typename MCSignType> std::string measure_set<MCSignType>::report() const {
template <DoubleOrComplex MCSignType> std::string measure_set<MCSignType>::report() const {
std::string res;
for (auto &[name, m] : measures_) {
auto str = m.report();
Expand All @@ -51,7 +52,7 @@ namespace triqs::mc_tools {
return res;
}

template <typename MCSignType> std::string measure_set<MCSignType>::get_timings(std::string const &prefix) const {
template <DoubleOrComplex MCSignType> std::string measure_set<MCSignType>::get_timings(std::string const &prefix) const {
std::string res;
for (auto const &[name, m] : measures_) { res += m.get_timings(name, prefix); }
return res;
Expand Down
5 changes: 3 additions & 2 deletions c++/triqs/mc_tools/mc_measure_set.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#pragma once

#include "./mc_concepts.hpp"
#include "./mc_measure.hpp"

#include <fmt/format.h>
Expand Down Expand Up @@ -46,9 +47,9 @@ namespace triqs::mc_tools {
* - measure_set::collect_results: Calls the measure::collect_results method for all registered MC measures.
* - measure_set::report: Concatenates the reports from all measurements by calling their measure::report method.
*
* @tparam MCSignType Type of the sign/weight of a MC configuration.
* @tparam MCSignType triqs::mc_tools::DoubleOrComplex type of the sign/weight of a MC configuration.
*/
template <typename MCSignType> class measure_set {
template <DoubleOrComplex MCSignType> class measure_set {
public:
/// Map type used for storing the measures.
using measure_map_t = std::map<std::string, measure<MCSignType>>;
Expand Down
7 changes: 4 additions & 3 deletions c++/triqs/mc_tools/mc_move.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
//
// Authors: Michel Ferrero, Olivier Parcollet, Nils Wentzell

#include "./mc_concepts.hpp"
#include "./mc_move.hpp"

#include <fmt/format.h>
Expand All @@ -28,21 +29,21 @@

namespace triqs::mc_tools {

template <typename MCSignType> void move<MCSignType>::collect_statistics(mpi::communicator const &c) {
template <DoubleOrComplex MCSignType> void move<MCSignType>::collect_statistics(mpi::communicator const &c) {
std::uint64_t nacc_tot = mpi::all_reduce(nacc_, c);
std::uint64_t nprop_tot = mpi::all_reduce(nprop_, c);
acc_rate_ = nacc_tot / static_cast<double>(nprop_tot);
ptr_->collect_statistics(c);
}

template <typename MCSignType> void move<MCSignType>::clear_statistics() {
template <DoubleOrComplex MCSignType> void move<MCSignType>::clear_statistics() {
nacc_ = 0;
nprop_ = 0;
acc_rate_ = -1;
ptr_->ms_clear_statistics();
}

template <typename MCSignType> std::string move<MCSignType>::get_statistics(std::string const &name, std::string const &prefix) const {
template <DoubleOrComplex MCSignType> std::string move<MCSignType>::get_statistics(std::string const &name, std::string const &prefix) const {
if (is_move_set_) {
auto str = fmt::format("{}Move set {}: Proposed = {}, Accepted = {}, Rate = {:.4f}\n", prefix, name, nprop_, nacc_, acc_rate_);
return str + ptr_->ms_get_statistics(prefix + " ");
Expand Down
6 changes: 3 additions & 3 deletions c++/triqs/mc_tools/mc_move.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
namespace triqs::mc_tools {

// Forward declaration.
template <typename MCSignType> class move_set;
template <DoubleOrComplex MCSignType> class move_set;

/**
* @brief Type erasure class for MC moves.
Expand All @@ -59,9 +59,9 @@ namespace triqs::mc_tools {
* - `void h5_write(h5::group, std::string const &, T const &) const`: Writes the move object of type `T` to HDF5.
* - `void h5_read(h5::group, std::string const &, T &)`: Reads the move object of type `T` from HDF5.
*
* @tparam MCSignType Type of the sign/weight of a MC configuration.
* @tparam MCSignType triqs::mc_tools::DoubleOrComplex type of the sign/weight of a MC configuration.
*/
template <typename MCSignType> class move {
template <DoubleOrComplex MCSignType> class move {
private:
// MC move concept defines the interface for MC moves.
struct move_concept {
Expand Down
Loading

0 comments on commit 6e0566a

Please sign in to comment.