Skip to content

Commit

Permalink
Math: Implement bivariate gaussian copula (WIP)
Browse files Browse the repository at this point in the history
Co-authored-by: Amrita Goswami <[email protected]>
  • Loading branch information
MSallermann and amritagos committed Mar 23, 2024
1 parent f7386a3 commit 25c78d4
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 1 deletion.
94 changes: 93 additions & 1 deletion include/util/math.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#pragma once
#include "fmt/core.h"
#include "util/erfinv.hpp"
#include <algorithm>
#include <cstddef>
#include <optional>
Expand Down Expand Up @@ -129,9 +131,14 @@ class power_law_distribution

/**
 * @brief Draws one sample.
 * A value is drawn from `dist` with the supplied generator and then mapped
 * through inverse_cumulative_probability.
 */
template<typename Generator>
ScalarT operator()( Generator & gen )
{
    const auto drawn = dist( gen );
    return inverse_cumulative_probability( drawn );
}

/**
 * @brief Inverse cumulative probability of the power law distribution.
 * Computes ( ( 1 - eps^(1-gamma) ) * x + eps^(1-gamma) )^( 1 / (1-gamma) ),
 * which evaluates to eps for x = 0 and to 1 for x = 1.
 * @param x cumulative probability (presumably in [0, 1]; operator() passes a draw from `dist`)
 * @return the value whose cumulative probability is x
 */
ScalarT inverse_cumulative_probability( ScalarT x )
{
    // eps^(1-gamma) appears twice in the formula, so evaluate it once
    const auto eps_pow = std::pow( eps, ( 1.0 - gamma ) );
    return std::pow( ( 1.0 - eps_pow ) * x + eps_pow, ( 1.0 / ( 1.0 - gamma ) ) );
}

Expand All @@ -157,6 +164,16 @@ class truncated_normal_distribution
std::normal_distribution<ScalarT> normal_dist{};
size_t max_tries = 5000;

// Maps a cumulative probability y back to a point of the gaussian with the stored
// mean and sigma (assuming erfinv is the inverse error function - see util/erfinv.hpp)
ScalarT inverse_cum_gauss( ScalarT y )
{
    const auto inv_erf = erfinv( 2.0 * y - 1 );
    return inv_erf * std::sqrt( 2.0 ) * sigma + mean;
}

// Cumulative probability function of the gaussian with the stored mean and sigma,
// evaluated at x
ScalarT cum_gauss( ScalarT x )
{
    const auto standardized = ( x - mean ) / ( sigma * std::sqrt( 2.0 ) );
    return 0.5 * ( 1 + std::erf( standardized ) );
}

public:
truncated_normal_distribution( ScalarT mean, ScalarT sigma, ScalarT eps )
: mean( mean ), sigma( sigma ), eps( eps ), normal_dist( std::normal_distribution<ScalarT>( mean, sigma ) )
Expand All @@ -174,6 +191,81 @@ class truncated_normal_distribution
}
return eps;
}

/**
 * @brief Inverse cumulative probability of the truncated normal distribution.
 * Rescales y linearly from [0, 1] onto [cum_gauss( eps ), 1] and maps the result
 * back through the inverse gaussian cumulative probability function.
 * @param y cumulative probability (presumably in [0, 1])
 * @return the value whose cumulative probability under the truncated distribution is y
 */
ScalarT inverse_cumulative_probability( ScalarT y )
{
    // cum_gauss and inverse_cum_gauss take a single argument; mean and sigma are read from the members
    const ScalarT cdf_at_eps = cum_gauss( eps );
    return inverse_cum_gauss( y * ( 1.0 - cdf_at_eps ) + cdf_at_eps );
}
};

/**
 * @brief Bivariate normal distribution
 * with mean mu = [0,0]
 * and covariance matrix Sigma = [[1, cov], [cov, 1]]
 * |cov| < 1 is required
 *
 * Two independent standard normal draws n1, n2 are combined as
 *   r1 = n1
 *   r2 = cov * n1 + sqrt( 1 - cov^2 ) * n2
 * so that both components have unit variance and covariance cov.
 */
template<typename ScalarT = double>
class bivariate_normal_distribution
{
private:
    ScalarT covariance;
    std::normal_distribution<ScalarT> normal_dist{};

public:
    /**
     * @param covariance off-diagonal element of the covariance matrix
     */
    bivariate_normal_distribution( ScalarT covariance ) : covariance( covariance ) {}

    /**
     * @brief Draws one correlated pair.
     * @param gen random number generator, advanced by the two normal draws
     * @return { r1, r2 }
     */
    template<typename Generator>
    std::array<ScalarT, 2> operator()( Generator & gen )
    {
        const ScalarT n1 = normal_dist( gen );
        const ScalarT n2 = normal_dist( gen );

        const ScalarT r1 = n1;
        // The independent draw n2 must scale the sqrt term, otherwise r2 is fully determined by n1
        const ScalarT r2 = covariance * n1 + std::sqrt( 1 - covariance * covariance ) * n2;

        return { r1, r2 };
    }
};

/**
 * @brief Bivariate gaussian copula
 * Draws a pair from a bivariate normal distribution with the given covariance,
 * maps each component to the unit interval with the standard gaussian cumulative
 * probability function and feeds the results into the inverse cumulative
 * probability functions of dist1 and dist2.
 *
 * dist1T and dist2T must provide `inverse_cumulative_probability( ScalarT )`.
 */
template<typename ScalarT, typename dist1T, typename dist2T>
class bivariate_gaussian_copula
{
private:
    ScalarT covariance;
    // No default member initializer: bivariate_normal_distribution has no default constructor,
    // it is constructed from the covariance in the constructor below
    bivariate_normal_distribution<ScalarT> bivariate_normal_dist;

    // Cumulative probability function for gaussian with mean 0 and variance 1
    ScalarT cum_gauss( ScalarT x )
    {
        return 0.5 * ( 1 + std::erf( ( x ) / std::sqrt( 2.0 ) ) );
    }

    dist1T dist1;
    dist2T dist2;

public:
    /**
     * @param covariance covariance of the underlying bivariate normal distribution
     * @param dist1 distribution providing the first marginal
     * @param dist2 distribution providing the second marginal
     */
    bivariate_gaussian_copula( ScalarT covariance, dist1T dist1, dist2T dist2 )
            // Initializers are listed in member declaration order
            : covariance( covariance ),
              bivariate_normal_dist( covariance ),
              dist1( dist1 ),
              dist2( dist2 )
    {
    }

    /**
     * @brief Draws one pair.
     * @param gen random number generator passed on to the bivariate normal distribution
     * @return { dist1 sample, dist2 sample }
     */
    template<typename Generator>
    std::array<ScalarT, 2> operator()( Generator & gen )
    {
        // 1. Draw from bivariate gaussian
        auto z = bivariate_normal_dist( gen );
        // 2. Transform marginals to unit interval
        std::array<ScalarT, 2> z_unit = { cum_gauss( z[0] ), cum_gauss( z[1] ) };
        // 3. Apply inverse transform sampling
        std::array<ScalarT, 2> res
            = { dist1.inverse_cumulative_probability( z_unit[0] ), dist2.inverse_cumulative_probability( z_unit[1] ) };
        return res;
    }
};

template<typename T>
Expand Down
1 change: 1 addition & 0 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ tests = [
['Test_Sampling', 'test/test_sampling.cpp'],
['Test_IO', 'test/test_io.cpp'],
['Test_Util', 'test/test_util.cpp'],
['Test_Prob', 'test/test_probability_distributions.cpp'],
]

# Catch2 dependency, looked up with the cmake method and requesting the
# Catch2::Catch2WithMain and Catch2::Catch2 modules
Catch2 = dependency('Catch2', method : 'cmake', modules : ['Catch2::Catch2WithMain', 'Catch2::Catch2'])
Expand Down
70 changes: 70 additions & 0 deletions test/test_probability_distributions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_floating_point.hpp>
#include <catch2/matchers/catch_matchers_range_equals.hpp>

#include "util/math.hpp"
#include <algorithm>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <random>
namespace fs = std::filesystem;

// Streams every element of the array followed by a single space
template<std::size_t N>
std::ostream & operator<<( std::ostream & os, std::array<double, N> const & v1 )
{
    // Elements are taken as double: an int parameter would truncate the fractional part
    for( double val : v1 )
    {
        os << val << " ";
    }
    return os;
}

// Samples the distribution N_Samples times with a generator seeded with 0 and writes
// one sample per line to <current working directory>/test/output/<filename>
template<typename distT>
void write_results_to_file( int N_Samples, distT dist, const std::string & filename )
{
    // Only relative components are appended: an absolute right-hand operand of
    // operator/ would replace the current path instead of extending it
    const auto output_dir = std::filesystem::current_path() / "test" / "output";
    // Create the containing directory, not the file path itself; otherwise a
    // directory occupies the file name and the stream below cannot be opened
    std::filesystem::create_directories( output_dir );
    const auto file = output_dir / filename;

    auto gen = std::mt19937( 0 );

    std::ofstream filestream( file );
    filestream << std::setprecision( 16 );

    for( int i = 0; i < N_Samples; i++ )
    {
        filestream << dist( gen ) << "\n";
    }
    // The stream is flushed and closed when it goes out of scope
}

// Writes samples of each distribution to a text file; no assertions are made here
TEST_CASE( "Test the probability distributions", "[prob]" )
{
    // Number of samples written per distribution
    constexpr int n_samples = 10000;

    auto truncated_normal = Seldon::truncated_normal_distribution( 1.0, 0.5, 0.1 );
    write_results_to_file( n_samples, truncated_normal, "truncated_normal.txt" );

    auto power_law = Seldon::power_law_distribution( 0.01, 2.1 );
    write_results_to_file( n_samples, power_law, "power_law.txt" );

    auto bivariate_normal = Seldon::bivariate_normal_distribution( 0.5 );
    write_results_to_file( n_samples, bivariate_normal, "bivariate_normal.txt" );
}

// TEST_CASE( "Test reading in the agents from a file", "[io_agents]" )
// {
// using namespace Seldon;
// using namespace Catch::Matchers;

// auto proj_root_path = fs::current_path();
// auto network_file = proj_root_path / fs::path( "test/res/opinions.txt" );

// auto agents = Seldon::agents_from_file<ActivityDrivenModel::AgentT>( network_file );

// std::vector<double> opinions_expected = { 2.1127107987061544, 0.8088982488089491, -0.8802809369462433 };
// std::vector<double> activities_expected = { 0.044554683389757696, 0.015813166022685163, 0.015863953902810535 };
// std::vector<double> reluctances_expected = { 1.0, 1.0, 2.3 };

// REQUIRE( agents.size() == 3 );

// for( size_t i = 0; i < agents.size(); i++ )
// {
// fmt::print( "{}", i );
// REQUIRE_THAT( agents[i].data.opinion, Catch::Matchers::WithinAbs( opinions_expected[i], 1e-16 ) );
// REQUIRE_THAT( agents[i].data.activity, Catch::Matchers::WithinAbs( activities_expected[i], 1e-16 ) );
// REQUIRE_THAT( agents[i].data.reluctance, Catch::Matchers::WithinAbs( reluctances_expected[i], 1e-16 ) );
// }
// }

0 comments on commit 25c78d4

Please sign in to comment.