From 7022a56181c68af5c5f2dab33d056e55d36ba107 Mon Sep 17 00:00:00 2001 From: Viktor Gal Date: Wed, 23 Oct 2019 14:27:37 +0200 Subject: [PATCH] seeding fix for xvalmmd fix #4783 --- .../internals/mmd/CrossValidationMMD.h | 11 +++-------- .../internals/CrossValidationMMD_unittest.cc | 9 ++++++--- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/shogun/statistical_testing/internals/mmd/CrossValidationMMD.h b/src/shogun/statistical_testing/internals/mmd/CrossValidationMMD.h index d9084ffa273..71a6c0934e2 100644 --- a/src/shogun/statistical_testing/internals/mmd/CrossValidationMMD.h +++ b/src/shogun/statistical_testing/internals/mmd/CrossValidationMMD.h @@ -195,16 +195,11 @@ struct CrossValidationMMD : PermutationMMD SGVector dummy_labels_x(m_n_x); SGVector dummy_labels_y(m_n_y); - auto instance_x=new CCrossValidationSplitting(new CBinaryLabels(dummy_labels_x), m_num_folds); - auto instance_y=new CCrossValidationSplitting(new CBinaryLabels(dummy_labels_y), m_num_folds); - random::seed(instance_x, prng); - random::seed(instance_y, prng); - - m_kfold_x=unique_ptr(instance_x); - m_kfold_y=unique_ptr(instance_y); + m_kfold_x=std::make_unique(new CBinaryLabels(dummy_labels_x), m_num_folds); + m_kfold_y=std::make_unique(new CBinaryLabels(dummy_labels_y), m_num_folds); random::seed(m_kfold_x.get(), prng); random::seed(m_kfold_y.get(), prng); - + m_stack=unique_ptr(new CSubsetStack()); const index_t size=m_n_x+m_n_y; diff --git a/tests/unit/statistical_testing/internals/CrossValidationMMD_unittest.cc b/tests/unit/statistical_testing/internals/CrossValidationMMD_unittest.cc index ec7f5fe9e15..9b62d700811 100644 --- a/tests/unit/statistical_testing/internals/CrossValidationMMD_unittest.cc +++ b/tests/unit/statistical_testing/internals/CrossValidationMMD_unittest.cc @@ -107,6 +107,7 @@ TEST(CrossValidationMMD, biased_full) kfold_p->put("seed", seed); kfold_q->put("seed", seed); + std::mt19937_64 permPRNG(seed); auto permutation_mmd=PermutationMMD(); permutation_mmd.m_stype=stype; permutation_mmd.m_num_null_samples=num_null_samples; @@ -134,7 +135,7 @@ TEST(CrossValidationMMD, biased_full) (feats_p->create_merged_copy(feats_q)); kernel->init(current_merged_feats, current_merged_feats); - auto p_value=permutation_mmd.p_value(kernel->get_kernel_matrix(), prng); + auto p_value=permutation_mmd.p_value(kernel->get_kernel_matrix(), permPRNG); EXPECT_EQ(cv.m_rejections(current_run*num_folds+current_fold, k), p_valueput("seed", seed); kfold_q->put("seed", seed); + std::mt19937_64 permPRNG(seed); auto permutation_mmd=PermutationMMD(); permutation_mmd.m_stype=stype; permutation_mmd.m_num_null_samples=num_null_samples; @@ -233,7 +235,7 @@ TEST(CrossValidationMMD, unbiased_full) (feats_p->create_merged_copy(feats_q)); kernel->init(current_merged_feats, current_merged_feats); - auto p_value=permutation_mmd.p_value(kernel->get_kernel_matrix(), prng); + auto p_value=permutation_mmd.p_value(kernel->get_kernel_matrix(), permPRNG); EXPECT_EQ(cv.m_rejections(current_run*num_folds+current_fold, k), p_valueput("seed", seed); kfold_q->put("seed", seed); + std::mt19937_64 permPRNG(seed); auto permutation_mmd=PermutationMMD(); permutation_mmd.m_stype=stype; permutation_mmd.m_num_null_samples=num_null_samples; @@ -333,7 +336,7 @@ TEST(CrossValidationMMD, unbiased_incomplete) (feats_p->create_merged_copy(feats_q)); kernel->init(current_merged_feats, current_merged_feats); - auto p_value=permutation_mmd.p_value(kernel->get_kernel_matrix(), prng); + auto p_value=permutation_mmd.p_value(kernel->get_kernel_matrix(), permPRNG); EXPECT_EQ(cv.m_rejections(current_run*num_folds+current_fold, k), p_value