Skip to content

Commit

Permalink
Merge pull request #39 from seldon-code/bivariate_gaussian_copula
Browse files Browse the repository at this point in the history
Bivariate gaussian copula for reluctances and activities
  • Loading branch information
MSallermann authored Mar 25, 2024
2 parents c061e83 + ade8ee9 commit 1180221
Show file tree
Hide file tree
Showing 8 changed files with 293 additions and 25 deletions.
2 changes: 1 addition & 1 deletion include/connectivity.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class TarjanConnectivityAlgo
{
lowest[v] = std::min( lowest[v], num[u] );
} // u not processed
} // u has been visited
} // u has been visited
}

// Now v has been processed
Expand Down
20 changes: 10 additions & 10 deletions include/models/ActivityDrivenModel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class ActivityDrivenModelAbstract : public Model<AgentT_>
}
}

void iteration() override{};
void iteration() override {};

protected:
NetworkT & network;
Expand Down Expand Up @@ -113,6 +113,8 @@ class ActivityDrivenModelAbstract : public Model<AgentT_>
power_law_distribution<> dist_activity( eps, gamma );
truncated_normal_distribution<> dist_reluctance( reluctance_mean, reluctance_sigma, reluctance_eps );

bivariate_gaussian_copula copula( covariance_factor, dist_activity, dist_reluctance );

auto mean_activity = dist_activity.mean();

// Initial conditions for the opinions, initialize to [-1,1]
Expand All @@ -121,18 +123,16 @@ class ActivityDrivenModelAbstract : public Model<AgentT_>
{
network.agents[i].data.opinion = dis_opinion( gen ); // Draw the opinion value

if( !mean_activities )
{ // Draw from a power law distribution (1-gamma)/(1-eps^(1-gamma)) * a^(-gamma)
network.agents[i].data.activity = dist_activity( gen );
}
else
{
network.agents[i].data.activity = mean_activity;
}
auto res = copula( gen );
network.agents[i].data.activity = res[0];

if( use_reluctances )
{
network.agents[i].data.reluctance = dist_reluctance( gen );
network.agents[i].data.reluctance = res[1];
}
if( mean_activities )
{
network.agents[i].data.activity = mean_activity;
}
}

Expand Down
112 changes: 112 additions & 0 deletions include/util/erfinv.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#pragma once
#include <cmath>
#include <limits>

namespace Seldon::Math
{

constexpr long double pi = 3.1415926535897932384626433832795028841971693993751L;
constexpr long double sqrt_pi = 1.7724538509055160272981674833411451827975494561224L;

// Implementation adapted from https://github.com/lakshayg/erfinv same as used in golang math library
template<typename T>
T erfinv( T x )
{
if( x < -1 || x > 1 )
{
return std::numeric_limits<T>::quiet_NaN();
}
else if( x == 1.0 )
{
return std::numeric_limits<T>::infinity();
}
else if( x == -1.0 )
{
return -std::numeric_limits<T>::infinity();
}

const T LN2 = 6.931471805599453094172321214581e-1L;

const T A0 = 1.1975323115670912564578e0L;
const T A1 = 4.7072688112383978012285e1L;
const T A2 = 6.9706266534389598238465e2L;
const T A3 = 4.8548868893843886794648e3L;
const T A4 = 1.6235862515167575384252e4L;
const T A5 = 2.3782041382114385731252e4L;
const T A6 = 1.1819493347062294404278e4L;
const T A7 = 8.8709406962545514830200e2L;

const T B0 = 1.0000000000000000000e0L;
const T B1 = 4.2313330701600911252e1L;
const T B2 = 6.8718700749205790830e2L;
const T B3 = 5.3941960214247511077e3L;
const T B4 = 2.1213794301586595867e4L;
const T B5 = 3.9307895800092710610e4L;
const T B6 = 2.8729085735721942674e4L;
const T B7 = 5.2264952788528545610e3L;

const T C0 = 1.42343711074968357734e0L;
const T C1 = 4.63033784615654529590e0L;
const T C2 = 5.76949722146069140550e0L;
const T C3 = 3.64784832476320460504e0L;
const T C4 = 1.27045825245236838258e0L;
const T C5 = 2.41780725177450611770e-1L;
const T C6 = 2.27238449892691845833e-2L;
const T C7 = 7.74545014278341407640e-4L;

const T D0 = 1.4142135623730950488016887e0L;
const T D1 = 2.9036514445419946173133295e0L;
const T D2 = 2.3707661626024532365971225e0L;
const T D3 = 9.7547832001787427186894837e-1L;
const T D4 = 2.0945065210512749128288442e-1L;
const T D5 = 2.1494160384252876777097297e-2L;
const T D6 = 7.7441459065157709165577218e-4L;
const T D7 = 1.4859850019840355905497876e-9L;

const T E0 = 6.65790464350110377720e0L;
const T E1 = 5.46378491116411436990e0L;
const T E2 = 1.78482653991729133580e0L;
const T E3 = 2.96560571828504891230e-1L;
const T E4 = 2.65321895265761230930e-2L;
const T E5 = 1.24266094738807843860e-3L;
const T E6 = 2.71155556874348757815e-5L;
const T E7 = 2.01033439929228813265e-7L;

const T F0 = 1.414213562373095048801689e0L;
const T F1 = 8.482908416595164588112026e-1L;
const T F2 = 1.936480946950659106176712e-1L;
const T F3 = 2.103693768272068968719679e-2L;
const T F4 = 1.112800997078859844711555e-3L;
const T F5 = 2.611088405080593625138020e-5L;
const T F6 = 2.010321207683943062279931e-7L;
const T F7 = 2.891024605872965461538222e-15L;

T abs_x = std::abs( x );

T r, num, den;

if( abs_x <= 0.85 )
{
r = 0.180625 - 0.25 * x * x;
num = ( ( ( ( ( ( ( A7 * r + A6 ) * r + A5 ) * r + A4 ) * r + A3 ) * r + A2 ) * r + A1 ) * r + A0 );
den = ( ( ( ( ( ( ( B7 * r + B6 ) * r + B5 ) * r + B4 ) * r + B3 ) * r + B2 ) * r + B1 ) * r + B0 );
return x * num / den;
}

r = std::sqrt( LN2 - std::log1p( -abs_x ) );
if( r <= 5.0 )
{
r = r - 1.6;
num = ( ( ( ( ( ( ( C7 * r + C6 ) * r + C5 ) * r + C4 ) * r + C3 ) * r + C2 ) * r + C1 ) * r + C0 );
den = ( ( ( ( ( ( ( D7 * r + D6 ) * r + D5 ) * r + D4 ) * r + D3 ) * r + D2 ) * r + D1 ) * r + D0 );
}
else
{
r = r - 5.0;
num = ( ( ( ( ( ( ( E7 * r + E6 ) * r + E5 ) * r + E4 ) * r + E3 ) * r + E2 ) * r + E1 ) * r + E0 );
den = ( ( ( ( ( ( ( F7 * r + F6 ) * r + F5 ) * r + F4 ) * r + F3 ) * r + F2 ) * r + F1 ) * r + F0 );
}

return std::copysign<T>( num / den, x );
}
} // namespace Seldon::Math
123 changes: 112 additions & 11 deletions include/util/math.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#pragma once
#include "fmt/core.h"
#include "util/erfinv.hpp"
#include <algorithm>
#include <cstddef>
#include <optional>
Expand Down Expand Up @@ -129,9 +131,19 @@ class power_law_distribution

template<typename Generator>
ScalarT operator()( Generator & gen )
{
return inverse_cdf( dist( gen ) );
}

ScalarT pdf( ScalarT x )
{
return ( 1.0 - gamma ) / ( 1.0 - std::pow( eps, ( 1 - gamma ) ) * std::pow( x, ( -gamma ) ) );
}

ScalarT inverse_cdf( ScalarT x )
{
return std::pow(
( 1.0 - std::pow( eps, ( 1.0 - gamma ) ) ) * dist( gen ) + std::pow( eps, ( 1.0 - gamma ) ),
( 1.0 - std::pow( eps, ( 1.0 - gamma ) ) ) * x + std::pow( eps, ( 1.0 - gamma ) ),
( 1.0 / ( 1.0 - gamma ) ) );
}

Expand All @@ -154,25 +166,114 @@ class truncated_normal_distribution
ScalarT mean{};
ScalarT sigma{};
ScalarT eps{};
std::normal_distribution<ScalarT> normal_dist{};
size_t max_tries = 5000;
std::uniform_real_distribution<ScalarT> uniform_dist{};

ScalarT inverse_cdf_gauss( ScalarT y )
{
return Math::erfinv( 2.0 * y - 1 ) * std::sqrt( 2.0 ) * sigma + mean;
}

ScalarT cdf_gauss( ScalarT x )
{
return 0.5 * ( 1 + std::erf( ( x - mean ) / ( sigma * std::sqrt( 2.0 ) ) ) );
}

ScalarT pdf_gauss( ScalarT x )
{
return 1.0 / ( sigma * std::sqrt( 2 * Math::pi ) ) * std::exp( -0.5 * std::pow( ( ( x - mean ) / sigma ), 2 ) );
}

public:
truncated_normal_distribution( ScalarT mean, ScalarT sigma, ScalarT eps )
: mean( mean ), sigma( sigma ), eps( eps ), normal_dist( std::normal_distribution<ScalarT>( mean, sigma ) )
: mean( mean ), sigma( sigma ), eps( eps ), uniform_dist( 0, 1 )
{
}

template<typename Generator>
ScalarT operator()( Generator & gen )
{
for( size_t i = 0; i < max_tries; i++ )
{
auto sample = normal_dist( gen );
if( sample > eps )
return sample;
}
return eps;
return inverse_cdf( uniform_dist( gen ) );
}

ScalarT inverse_cdf( ScalarT y )
{
return inverse_cdf_gauss( y * ( 1.0 - cdf_gauss( eps ) ) + cdf_gauss( eps ) );
}

ScalarT pdf( ScalarT x )
{
if( x < eps )
return 0.0;
else
return 1.0 / ( 1.0 - cdf_gauss( eps ) ) * pdf_gauss( x );
}
};

/**
* @brief Bivariate normal distribution
* with mean mu = [0,0]
* and covariance matrix Sigma = [[1, cov], [cov, 1]]
* |cov| < 1 is required
*/
template<typename ScalarT = double>
class bivariate_normal_distribution
{
private:
ScalarT covariance;
std::normal_distribution<ScalarT> normal_dist{};

public:
bivariate_normal_distribution( ScalarT covariance ) : covariance( covariance ) {}

template<typename Generator>
std::array<ScalarT, 2> operator()( Generator & gen )
{
ScalarT n1 = normal_dist( gen );
ScalarT n2 = normal_dist( gen );

ScalarT r1 = n1;
ScalarT r2 = covariance * n1 + std::sqrt( 1.0 - covariance * covariance ) * n2;

return { r1, r2 };
}
};

template<typename ScalarT, typename dist1T, typename dist2T>
class bivariate_gaussian_copula
{
private:
ScalarT covariance;
bivariate_normal_distribution<ScalarT> biv_normal_dist{};
// std::normal_distribution<ScalarT> normal_dist{};

// Cumulative probability function for gaussian with mean 0 and variance 1
ScalarT cdf_gauss( ScalarT x )
{
return 0.5 * ( 1 + std::erf( ( x ) / std::sqrt( 2.0 ) ) );
}

dist1T dist1;
dist2T dist2;

public:
bivariate_gaussian_copula( ScalarT covariance, dist1T dist1, dist2T dist2 )
: covariance( covariance ),
biv_normal_dist( bivariate_normal_distribution( covariance ) ),
dist1( dist1 ),
dist2( dist2 )
{
}

template<typename Generator>
std::array<ScalarT, 2> operator()( Generator & gen )
{
// 1. Draw from bivariate gaussian
auto z = biv_normal_dist( gen );
// 2. Transform marginals to unit interval
std::array<ScalarT, 2> z_unit = { cdf_gauss( z[0] ), cdf_gauss( z[1] ) };
// 3. Apply inverse transform sampling
std::array<ScalarT, 2> res = { dist1.inverse_cdf( z_unit[0] ), dist2.inverse_cdf( z_unit[1] ) };
return res;
}
};

Expand Down
1 change: 1 addition & 0 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ tests = [
['Test_Sampling', 'test/test_sampling.cpp'],
['Test_IO', 'test/test_io.cpp'],
['Test_Util', 'test/test_util.cpp'],
['Test_Prob', 'test/test_probability_distributions.cpp'],
]

Catch2 = dependency('Catch2', method : 'cmake', modules : ['Catch2::Catch2WithMain', 'Catch2::Catch2'])
Expand Down
3 changes: 1 addition & 2 deletions src/config_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,7 @@ void validate_settings( const SimulationOptions & options )
{
const std::string basic_deff_msg
= "The basic Deffuant model has not been implemented with multiple dimensions";
check(
name_and_var( model_settings.dim ), []( auto x ) { return x == 1; }, basic_deff_msg );
check( name_and_var( model_settings.dim ), []( auto x ) { return x == 1; }, basic_deff_msg );
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion test/test_network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ TEST_CASE( "Testing the network class" )
auto weight = buffer_w[i_neighbour];
std::tuple<size_t, size_t, Network::WeightT> edge{
neighbour, i_agent, weight
}; // Note that i_agent and neighbour are flipped compared to before
}; // Note that i_agent and neighbour are flipped compared to before
REQUIRE( old_edges.contains( edge ) ); // can we find the transposed edge?
old_edges.extract( edge ); // extract the edge afterwards
}
Expand Down
Loading

0 comments on commit 1180221

Please sign in to comment.