From 18fb7f76ea8d020f4c3c89b2aca316a9c8740f7e Mon Sep 17 00:00:00 2001 From: sachaMorin Date: Wed, 9 Aug 2023 14:02:09 -0400 Subject: [PATCH] Fix bootstrap seeding --- stepmix/bootstrap.py | 3 +++ test/test_bootstrap.py | 47 ++++++++++++++++++++++++++++-------------- 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/stepmix/bootstrap.py b/stepmix/bootstrap.py index 4ca1640..cc9625b 100644 --- a/stepmix/bootstrap.py +++ b/stepmix/bootstrap.py @@ -2,6 +2,7 @@ import itertools import pandas as pd import warnings +import copy import numpy as np import tqdm @@ -93,9 +94,11 @@ def bootstrap( Various statistics of bootstrapped estimators. """ check_is_fitted(estimator) + estimator = copy.deepcopy(estimator) estimator.set_params(random_state=random_state) if sampler is not None: check_is_fitted(sampler) + sampler = copy.deepcopy(sampler) sampler.set_params(random_state=random_state) n_samples = X.shape[0] diff --git a/test/test_bootstrap.py b/test/test_bootstrap.py index 9c6b698..669de88 100644 --- a/test/test_bootstrap.py +++ b/test/test_bootstrap.py @@ -116,43 +116,59 @@ def test_bootstrap_df(data_nested, kwargs_nested): model_1.fit(data_nested, data_nested) model_1.bootstrap(data_nested, data_nested, n_repetitions=3) + @pytest.mark.parametrize("parametric", [True, False]) -def test_boostrap_seed(data_nested, kwargs_nested, parametric): - """Call bootstrap twice and make sure the results are the same if seeded.""" +def test_bootstrap_seed(data_nested, kwargs_nested, parametric): + """Call bootstrap twice and make sure the results are the same if seeded. + + Identify classes should not affect stats either.""" data_nested = pd.DataFrame(data_nested) + kwargs_nested["verbose"] = 0 - model_1 = StepMix(**kwargs_nested) # Base model is seeded + model_1 = StepMix(**kwargs_nested) # Base model is seeded model_1.fit(data_nested, data_nested) - _, stats = model_1.bootstrap(data_nested, data_nested, n_repetitions=10, parametric=parametric) + _, stats = model_1.bootstrap( + data_nested, data_nested, n_repetitions=10, parametric=parametric, identify_classes=True + ) - _, stats2 = model_1.bootstrap(data_nested, data_nested, n_repetitions=10, parametric=parametric) + _, stats2 = model_1.bootstrap( + data_nested, data_nested, n_repetitions=10, parametric=parametric, identify_classes=False + ) assert np.all(stats == stats2) -def test_boostrap_seed_sampler(data_nested, kwargs_nested): + +def _test_bootstrap_seed_sampler(data_nested, kwargs_nested): """Call bootstrap twice with sampler and make sure the results are the same if seeded.""" data_nested = pd.DataFrame(data_nested) - model_1 = StepMix(**kwargs_nested) # Base model is seeded + model_1 = StepMix(**kwargs_nested) # Base model is seeded model_1.fit(data_nested, data_nested) - kwargs_nested['n_components'] = 2 + kwargs_nested["n_components"] = 2 sampler = StepMix(**kwargs_nested) sampler.fit(data_nested, data_nested) - _, stats0 = model_1.bootstrap(data_nested, data_nested, n_repetitions=10, parametric=True, sampler=model_1) + _, stats0 = model_1.bootstrap( + data_nested, data_nested, n_repetitions=10, parametric=True, sampler=model_1 + ) - _, stats1 = model_1.bootstrap(data_nested, data_nested, n_repetitions=10, parametric=True, sampler=sampler) + _, stats1 = model_1.bootstrap( + data_nested, data_nested, n_repetitions=10, parametric=True, sampler=sampler + ) - _, stats2 = model_1.bootstrap(data_nested, data_nested, n_repetitions=10, parametric=True, sampler=sampler) + _, stats2 = model_1.bootstrap( + data_nested, data_nested, n_repetitions=10, parametric=True, sampler=sampler + ) - assert np.all(stats0 == stats1) + assert not np.all(stats0 == stats1) assert np.all(stats1 == stats2) + @pytest.mark.parametrize("parametric", [True, False]) -def test_boostrap_mm(data_nested, kwargs_nested, parametric): +def test_bootstrap_mm(data_nested, kwargs_nested, parametric): """Call bootstrap with mm only and make sure it does not raise errors.""" data_nested = pd.DataFrame(data_nested) @@ -162,7 +178,8 @@ def test_boostrap_mm(data_nested, kwargs_nested, parametric): model_1.bootstrap(data_nested, n_repetitions=10, parametric=parametric) -def test_boostrap_mm_sampler(data_nested, kwargs_nested): + +def test_bootstrap_mm_sampler(data_nested, kwargs_nested): """Call sampler bootstrap with mm only and make sure it does not raise errors.""" data_nested = pd.DataFrame(data_nested) @@ -170,7 +187,7 @@ def test_boostrap_mm_sampler(data_nested, kwargs_nested): model_1.fit(data_nested) - kwargs_nested['n_components'] = 2 + kwargs_nested["n_components"] = 2 sampler = StepMix(**kwargs_nested) sampler.fit(data_nested)