diff --git a/include/util/math.hpp b/include/util/math.hpp index e6041fc..bbde5e6 100644 --- a/include/util/math.hpp +++ b/include/util/math.hpp @@ -1,4 +1,6 @@ #pragma once +#include "fmt/core.h" +#include "util/erfinv.hpp" #include #include #include @@ -129,9 +131,14 @@ class power_law_distribution template ScalarT operator()( Generator & gen ) + { + return inverse_cumulative_probability( dist( gen ) ); + } + + ScalarT inverse_cumulative_probability( ScalarT x ) { return std::pow( - ( 1.0 - std::pow( eps, ( 1.0 - gamma ) ) ) * dist( gen ) + std::pow( eps, ( 1.0 - gamma ) ), + ( 1.0 - std::pow( eps, ( 1.0 - gamma ) ) ) * x + std::pow( eps, ( 1.0 - gamma ) ), ( 1.0 / ( 1.0 - gamma ) ) ); } @@ -157,6 +164,16 @@ class truncated_normal_distribution std::normal_distribution normal_dist{}; size_t max_tries = 5000; + ScalarT inverse_cum_gauss( ScalarT y ) + { + return erfinv( 2.0 * y - 1 ) * std::sqrt( 2.0 ) * sigma + mean; + } + + ScalarT cum_gauss( ScalarT x ) + { + return 0.5 * ( 1 + std::erf( ( x - mean ) / ( sigma * std::sqrt( 2.0 ) ) ) ); + } + public: truncated_normal_distribution( ScalarT mean, ScalarT sigma, ScalarT eps ) : mean( mean ), sigma( sigma ), eps( eps ), normal_dist( std::normal_distribution( mean, sigma ) ) @@ -174,6 +191,81 @@ class truncated_normal_distribution } return eps; } + + ScalarT inverse_cumulative_probability( ScalarT y ) + { + return inverse_cum_gauss( + y * ( 1.0 - cum_gauss( eps, sigma, mean ) ) + cum_gauss( eps, sigma, mean ), sigma, mean ); + } +}; + +/** + * @brief Bivariate normal distribution + * with mean mu = [0,0] + * and covariance matrix Sigma = [[1, cov], [cov, 1]] + * |cov| < 1 is required + */ +template +class bivariate_normal_distribution +{ +private: + ScalarT covariance; + std::normal_distribution normal_dist{}; + +public: + bivariate_normal_distribution( ScalarT covariance ) : covariance( covariance ) {} + + template + std::array operator()( Generator & gen ) + { + ScalarT n1 = normal_dist( gen ); + ScalarT n2 = normal_dist( gen ); + + ScalarT r1 = n1; + ScalarT r2 = covariance * n1 + std::sqrt( 1 - covariance * covariance ); + + return { r1, r2 }; + } +}; + +template +class bivariate_gaussian_copula +{ +private: + ScalarT covariance; + bivariate_normal_distribution bivariate_normal_dist{}; + // std::normal_distribution normal_dist{}; + + // Cumulative probability function for gaussian with mean 0 and variance 1 + ScalarT cum_gauss( ScalarT x ) + { + return 0.5 * ( 1 + std::erf( ( x ) / std::sqrt( 2.0 ) ) ); + } + + dist1T dist1; + dist2T dist2; + +public: + bivariate_gaussian_copula( ScalarT covariance, dist1T dist1, dist2T dist2 ) + : covariance( covariance ), + dist1( dist1 ), + dist2( dist2 ), + bivariate_normal_dist( bivariate_normal_dist( covariance ) ) + { + } + + template + std::array operator()( Generator & gen ) + { + // 1. Draw from bivariate gaussian + auto z = bivariate_normal_dist( gen ); + // 2. Transform marginals to unit interval + std::array z_unit = { cum_gauss( z[0] ), cum_gauss( z[1] ) }; + // 3. Apply inverse transform sampling + std::array res + = { dist1.inverse_cumulative_probability( z_unit[0] ), dist2.inverse_cumulative_probability( z_unit[1] ) }; + return res; + } }; template diff --git a/meson.build b/meson.build index e7017c3..06e7941 100644 --- a/meson.build +++ b/meson.build @@ -32,6 +32,7 @@ tests = [ ['Test_Sampling', 'test/test_sampling.cpp'], ['Test_IO', 'test/test_io.cpp'], ['Test_Util', 'test/test_util.cpp'], + ['Test_Prob', 'test/test_probability_distributions.cpp'], ] Catch2 = dependency('Catch2', method : 'cmake', modules : ['Catch2::Catch2WithMain', 'Catch2::Catch2']) diff --git a/test/test_probability_distributions.cpp b/test/test_probability_distributions.cpp new file mode 100644 index 0000000..74fd6ab --- /dev/null +++ b/test/test_probability_distributions.cpp @@ -0,0 +1,70 @@ +#include +#include +#include + +#include "util/math.hpp" +#include +#include +#include +#include +#include +namespace fs = std::filesystem; + +template +std::ostream & operator<<( std::ostream & os, std::array const & v1 ) +{ + std::for_each( begin( v1 ), end( v1 ), [&os]( int val ) { os << val << " "; } ); + return os; +} + +// Samples the distribution n_samples times and writes results to file +template +void write_results_to_file( int N_Samples, distT dist, const std::string & filename ) +{ + auto proj_root_path = fs::current_path(); + auto file = proj_root_path / fs::path( "/test/output/" + filename ); + fs::create_directories( file ); + + auto gen = std::mt19937( 0 ); + + std::ofstream filestream( file ); + filestream << std::setprecision( 16 ); + + for( int i = 0; i < N_Samples; i++ ) + { + filestream << dist( gen ) << "\n"; + } + filestream.close(); +} + +TEST_CASE( "Test the probability distributions", "[prob]" ) +{ + write_results_to_file( 10000, Seldon::truncated_normal_distribution( 1.0, 0.5, 0.1 ), "truncated_normal.txt" ); + write_results_to_file( 10000, Seldon::power_law_distribution( 0.01, 2.1 ), "power_law.txt" ); + write_results_to_file( 10000, Seldon::bivariate_normal_distribution( 0.5 ), "bivariate_normal.txt" ); +} + +// TEST_CASE( "Test reading in the agents from a file", "[io_agents]" ) +// { +// using namespace Seldon; +// using namespace Catch::Matchers; + +// auto proj_root_path = fs::current_path(); +// auto network_file = proj_root_path / fs::path( "test/res/opinions.txt" ); + +// auto agents = Seldon::agents_from_file( network_file ); + +// std::vector opinions_expected = { 2.1127107987061544, 0.8088982488089491, -0.8802809369462433 }; +// std::vector activities_expected = { 0.044554683389757696, 0.015813166022685163, 0.015863953902810535 }; +// std::vector reluctances_expected = { 1.0, 1.0, 2.3 }; + +// REQUIRE( agents.size() == 3 ); + +// for( size_t i = 0; i < agents.size(); i++ ) +// { +// fmt::print( "{}", i ); +// REQUIRE_THAT( agents[i].data.opinion, Catch::Matchers::WithinAbs( opinions_expected[i], 1e-16 ) ); +// REQUIRE_THAT( agents[i].data.activity, Catch::Matchers::WithinAbs( activities_expected[i], 1e-16 ) ); +// REQUIRE_THAT( agents[i].data.reluctance, Catch::Matchers::WithinAbs( reluctances_expected[i], 1e-16 ) ); +// } +// } \ No newline at end of file