Skip to content

Commit

Permalink
Update type erasure in random_generator to make it more flexible
Browse files Browse the repository at this point in the history
  • Loading branch information
Thoemi09 committed Nov 5, 2024
1 parent 5937cd3 commit 56717fb
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 22 deletions.
26 changes: 26 additions & 0 deletions c++/triqs/mc_tools/MersenneRNG.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

#include <cstdlib>
#include <float.h>
#include <iostream>
#include <iterator>
#include <stdexcept>

namespace triqs::mc_tools::RandomGenerators {

Expand Down Expand Up @@ -215,6 +218,29 @@ namespace triqs::mc_tools::RandomGenerators {
// inline double operator()() {
// return ((double)(randomMT())/0xFFFFFFFFU);
// }

// Same interface as boost::variate_generator.
[[nodiscard]] RandMT &engine() { return *this; }
[[nodiscard]] RandMT const &engine() const { return *this; }

// Write a textual representation of a RandMT object to `std::ostream`.
friend std::ostream &operator<<(std::ostream &os, const RandMT &rng) {
for (uint32 i = 0; i < std::size(rng.state); ++i) os << rng.state[i] << ' ';
long distance = rng.next - rng.state; // NOLINT
os << distance << ' ' << rng.left << ' ' << rng.initseed << ' ' << rng.seed_save;
if (!os) throw std::runtime_error("Error writing a triqs::mc_tools::RandomGenerators::RandMT to ostream.");
return os;
}

// Read a textual representation of a RandMT object from `std::istream`.
friend std::istream &operator>>(std::istream &is, RandMT &rng) {
for (uint32 i = 0; i < std::size(rng.state); ++i) { is >> rng.state[i] >> std::ws; }
long distance = 0;
is >> distance >> std::ws >> rng.left >> std::ws >> rng.initseed >> std::ws >> rng.seed_save;
rng.next = rng.state + distance; // NOLINT
if (!is) throw std::runtime_error("Error reading a triqs::mc_tools::RandomGenerators::RandMT from istream.");
return is;
};
};

} // namespace triqs::mc_tools::RandomGenerators
Expand Down
19 changes: 13 additions & 6 deletions c++/triqs/mc_tools/random_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,30 @@

namespace triqs::mc_tools {

random_generator::random_generator(std::string name, std::uint32_t seed) : name_(std::move(name)) {
random_generator::random_generator(std::string name, std::uint32_t seed, std::size_t buffer_size) : buffer_(buffer_size), name_(std::move(name)) {
initialize_rng(name_, seed);
refill();
}

void random_generator::initialize_rng(std::string const &name, std::uint32_t seed) {
// empty string corresponds to RandMT
if (name_.empty()) {
gen_ = utility::buffered_function<double>(RandomGenerators::RandMT(seed));
if (name.empty()) {
using rng_t = RandomGenerators::RandMT;
ptr_ = std::make_unique<rng_model<rng_t>>(seed);
return;
}

// now boost random number generators
#define DRNG(r, data, XX) \
if (name_ == AS_STRING(XX)) { \
gen_ = utility::buffered_function<double>(boost::variate_generator(boost::XX{seed}, boost::uniform_real<>{})); \
if (name == AS_STRING(XX)) { \
using rng_t = boost::variate_generator<boost::XX, boost::uniform_real<double>>; \
ptr_ = std::make_unique<rng_model<rng_t>>(rng_t{boost::XX{seed}, boost::uniform_real<>{}}); \
return; \
}
BOOST_PP_SEQ_FOR_EACH(DRNG, ~, RNG_LIST)

// throw an exception if the given name is not recognized
throw std::runtime_error(fmt::format("Error in random_generator: RNG with name {} is not supported", name_));
throw std::runtime_error(fmt::format("Error in random_generator::initialize_rng: RNG with name {} is not supported", name));
}

std::string random_generator_names(std::string const &sep) {
Expand Down
108 changes: 99 additions & 9 deletions c++/triqs/mc_tools/random_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@

#pragma once

#include "../utility/buffered_function.hpp"
#include <h5/h5.hpp>

#include <cassert>
#include <cmath>
#include <concepts>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <sstream>
#include <utility>
#include <vector>

namespace triqs::mc_tools {
Expand All @@ -41,9 +45,40 @@ namespace triqs::mc_tools {
* calls to the RNG.
*/
class random_generator {
private:
// RNG concept defines the interface for RNGs.
struct rng_concept {
virtual ~rng_concept() = default;
virtual double operator()() = 0;
virtual void refill(std::vector<double> &) = 0;
virtual std::ostream &to_ostream(std::ostream &) const = 0;
virtual std::istream &from_istream(std::istream &) = 0;
};

// RNG model implements the RNG concept by calling the appropriate methods of the type erased object.
template <typename T> struct rng_model : public rng_concept {
T rng_;
rng_model(T rng) : rng_{std::move(rng)} {}
double operator()() override { return rng_(); }
void refill(std::vector<double> &buffer) override {
for (auto &x : buffer) x = rng_();
}
std::ostream &to_ostream(std::ostream &os) const override {
os << rng_.engine();
return os;
}
std::istream &from_istream(std::istream &is) override {
is >> rng_.engine();
return is;
}
};

public:
/// Default seed for the underlying RNG.
static constexpr std::uint32_t default_seed = 198;

/// Default constructor uses Boost's Mersenne Twister 19937 RNG.
random_generator() : random_generator("mt19937", 198) {}
random_generator() : random_generator("mt19937", default_seed) {}

/**
* @brief Construct a random generator by wrapping the specified RNG and seeding it with the given seed.
Expand All @@ -56,8 +91,9 @@ namespace triqs::mc_tools {
*
* @param name Name of the RNG to be used.
* @param seed Seed for the RNG.
* @param buffer_size Size of the buffer used to store random numbers.
*/
random_generator(std::string name, std::uint32_t seed);
random_generator(std::string name, std::uint32_t seed, std::size_t buffer_size = 1000);

/// Deleted copy constructor.
random_generator(random_generator const &) = delete;
Expand All @@ -81,27 +117,33 @@ namespace triqs::mc_tools {
template <typename T>
requires std::integral<T>
T operator()(T i) {
return (i == 1 ? 0 : static_cast<T>(std::floor(i * gen_())));
return (i == 1 ? 0 : static_cast<T>(std::floor(i * this->operator()())));
}

/**
* @brief Look ahead at the next value that will be generated with a call to operator()().
* @return Uniform random double from the interval `[0, 1)`.
*/
[[nodiscard]] double preview() { return gen_.preview(); }
[[nodiscard]] double preview() {
if (idx_ > buffer_.size() - 1) refill();
return buffer_[idx_];
}

/**
* @brief Generate a random sample from the uniform distribution defined on the interval `[0, 1)`.
* @return Uniform random double from the interval `[0, 1)`.
*/
double operator()() { return gen_(); }
double operator()() {
if (idx_ > buffer_.size() - 1) refill();
return buffer_[idx_++];
}

/**
* @brief Generate a random sample from the uniform distribution defined on the interval `[0, b)`.
* @param b Upper bound of the interval.
* @return Uniform random double from the interval `[0, b)`.
*/
double operator()(double b) { return b * (gen_()); }
double operator()(double b) { return b * (this->operator()()); }

/**
* @brief Generate a random sample from the uniform distribution defined on the interval `[a, b)`.
Expand All @@ -112,11 +154,59 @@ namespace triqs::mc_tools {
*/
double operator()(double a, double b) {
assert(b > a);
return a + (b - a) * (gen_());
return a + (b - a) * (this->operator()());
}

/**
* @brief Write the RNG object to HDF5.
*
* @param g h5::group to be written to.
* @param name Name of the dataset/subgroup.
* @param rng RNG object to be written.
*/
friend void h5_write(h5::group g, std::string const &name, random_generator const &rng) {
auto gr = g.create_group(name);
h5::write(gr, "name", rng.name_);
h5::write(gr, "buffer", rng.buffer_);
h5::write(gr, "idx", rng.idx_);
std::ostringstream os;
rng.ptr_->to_ostream(os);
h5::write(gr, "rng", os.str());
}

/**
* @brief Read the RNG object from HDF5.
*
* @param g h5::group to be read from.
* @param name Name of the dataset/subgroup.
* @param rng RNG object to be read into.
*/
friend void h5_read(h5::group g, std::string const &name, random_generator &rng) {
auto gr = g.open_group(name);
h5::read(gr, "name", rng.name_);
h5::read(gr, "buffer", rng.buffer_);
h5::read(gr, "idx", rng.idx_);
rng.initialize_rng(rng.name_, default_seed);
std::string rng_state;
h5::read(gr, "rng", rng_state);
std::istringstream is{rng_state};
rng.ptr_->from_istream(is);
}

private:
// Refill the buffer.
void refill() {
ptr_->refill(buffer_);
idx_ = 0;
}

// Initialize the RNG.
void initialize_rng(std::string const &name, std::uint32_t seed);

private:
utility::buffered_function<double> gen_;
std::unique_ptr<rng_concept> ptr_;
size_t idx_{0};
std::vector<double> buffer_;
std::string name_;
};

Expand Down
68 changes: 61 additions & 7 deletions test/c++/mc_tools/mctools_random_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
//
// Authors: Michel Ferrero, Olivier Parcollet, Nils Wentzell

#include <triqs/mc_tools/MersenneRNG.hpp>
#include <triqs/mc_tools/random_generator.hpp>
#include <triqs/test_tools/arrays.hpp>

Expand All @@ -27,16 +28,24 @@
#include <boost/random/variate_generator.hpp>
#include <fmt/ranges.h>

#include <sstream>
#include <string>

TEST(TRIQSMCTools, RandomGeneratorNames) {
fmt::print("{}\n", triqs::mc_tools::random_generator_names());
fmt::print("{}\n", triqs::mc_tools::random_generator_names_list());
}

TEST(TRIQSMCTools, RandomGeneratorBoostGenerators) {
TEST(TRIQSMCTools, RandomGeneratorUnderlyingRNG) {
using namespace triqs::mc_tools;
int const seed = 0x18a2b3c4;
auto check_rng = [](auto &&rng, auto &&boost_rng) {
for (int i = 0; i < 100000; ++i) EXPECT_DOUBLE_EQ(rng(), boost_rng());
auto check_rng = [](auto &&rng, auto &&exp_rng) {
for (int i = 0; i < 10; ++i) {
double const val = rng.preview();
double const exp_val = exp_rng();
EXPECT_DOUBLE_EQ(rng(), exp_val);
EXPECT_DOUBLE_EQ(val, exp_val);
}
};

check_rng(random_generator("mt19937", seed), boost::variate_generator(boost::mt19937{seed}, boost::uniform_real<>{}));
Expand All @@ -51,15 +60,60 @@ TEST(TRIQSMCTools, RandomGeneratorBoostGenerators) {
check_rng(random_generator("lagged_fibonacci23209", seed), boost::variate_generator(boost::lagged_fibonacci23209{seed}, boost::uniform_real<>{}));
check_rng(random_generator("lagged_fibonacci44497", seed), boost::variate_generator(boost::lagged_fibonacci44497{seed}, boost::uniform_real<>{}));
check_rng(random_generator("ranlux3", seed), boost::variate_generator(boost::ranlux3{seed}, boost::uniform_real<>{}));
check_rng(random_generator("", seed), triqs::mc_tools::RandomGenerators::RandMT{seed});
}

TEST(TRIQSMCTools, RandomGeneratorRestoreRandMT) {
using namespace triqs::mc_tools::RandomGenerators;
auto rng = RandMT();
auto rng2 = RandMT();
for (int i = 0; i < 10; ++i) rng();
EXPECT_NE(rng(), rng2());
std::stringstream ss;
ss << rng;
ss >> rng2;
for (int i = 0; i < 10; ++i) EXPECT_EQ(rng(), rng2());
}

TEST(TRIQSMCTools, RandomGeneratorHDF5) {
using namespace triqs::mc_tools;
auto check_hdf5 = [](std::string const &name) {
int const seed = 0x18a2b3c4;
auto rng = random_generator(name, seed);
for (int i = 0; i < 10; ++i) rng();
auto rng2 = rw_h5(rng, "mctools_random_generator_" + name, name);
for (int i = 0; i < 10; ++i) EXPECT_DOUBLE_EQ(rng(), rng2());
};

// boost RNGs
for (auto const &name : random_generator_names_list()) check_hdf5(name);

// RandMT
using namespace std::string_literals;
int const seed = 0x18a2b3c4;
auto rng = random_generator("", seed);
for (int i = 0; i < 10; ++i) rng();
auto rng2 = rw_h5(rng, "mctools_random_generator"s + "_RandMT"s, "RandMT");
for (int i = 0; i < 10; ++i) EXPECT_DOUBLE_EQ(rng(), rng2());
}

TEST(TRIQSMCTools, RandomGeneratorRandMTPreview) {
TEST(TRIQSMCTools, RandomGeneratorMoveOperation) {
using namespace triqs::mc_tools;
auto rng = random_generator();
for (int i = 0; i < 100; ++i) {
auto pval = rng.preview();
EXPECT_EQ(pval, rng());
auto rng2 = random_generator();
for (int i = 0; i < 10; ++i) {
rng();
rng2();
}

// move constructor
auto rng3 = std::move(rng);
for (int i = 0; i < 10; ++i) EXPECT_DOUBLE_EQ(rng2(), rng3());

// move assignment
auto rng4 = random_generator();
rng4 = std::move(rng2);
for (int i = 0; i < 10; ++i) EXPECT_DOUBLE_EQ(rng3(), rng4());
}

MAKE_MAIN;

0 comments on commit 56717fb

Please sign in to comment.