From 01e7cff4305596a4403c28ee73d145ef2368db0e Mon Sep 17 00:00:00 2001 From: sachaMorin Date: Thu, 15 Feb 2024 10:19:49 -0500 Subject: [PATCH] Expose parametric in bootstrap_stats and improve relevant test --- stepmix/stepmix.py | 12 ++++++++++-- test/test_bootstrap.py | 13 +++++++------ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/stepmix/stepmix.py b/stepmix/stepmix.py index d71c5b9..b42d1ad 100644 --- a/stepmix/stepmix.py +++ b/stepmix/stepmix.py @@ -1150,7 +1150,13 @@ def bootstrap( ) def bootstrap_stats( - self, X, Y=None, n_repetitions=1000, sample_weight=None, progress_bar=True + self, + X, + Y=None, + n_repetitions=1000, + sample_weight=None, + parametric=False, + progress_bar=True, ): """Non-parametric boostrap of StepMix estimator. Obtain boostrapped parameters and some statistics (mean and standard deviation). @@ -1164,6 +1170,8 @@ def bootstrap_stats( sample_weight : array-like of shape(n_samples,), default=None n_repetitions: int Number of repetitions to fit. + parametric: bool, default=False + Use parametric bootstrap instead of non-parametric. Data will be generated by sampling the estimator. progress_bar : bool, default=True Display a tqdm progress bar for repetitions. Returns @@ -1185,7 +1193,7 @@ def bootstrap_stats( Y=Y, n_repetitions=n_repetitions, sample_weight=sample_weight, - parametric=False, + parametric=parametric, progress_bar=progress_bar, ) diff --git a/test/test_bootstrap.py b/test/test_bootstrap.py index 289476d..9199803 100644 --- a/test/test_bootstrap.py +++ b/test/test_bootstrap.py @@ -83,7 +83,8 @@ def test_nested_permutation(data_nested, kwargs_nested): "ignore::sklearn.exceptions.ConvergenceWarning" ) # Ignore convergence warnings for same reason @pytest.mark.parametrize("model", EMISSION_DICT.keys()) -def test_bootstrap(data, kwargs, model): +@pytest.mark.parametrize("parametric", [False, True]) +def test_bootstrap(data, kwargs, model, parametric): """Call the boostrap procedure on all models and make sure they don't raise errors. The data may not make sense for the model. We therefore do not test a particular output here.""" @@ -100,12 +101,12 @@ def test_bootstrap(data, kwargs, model): model_1 = StepMix(n_steps=1, **kwargs) model_1.fit(X, Y) - if model != "covariate": - model_1.bootstrap_stats(X, Y, n_repetitions=3) - else: - # Should raise error. Can't sample from a covariate model + if parametric and model == "covariate": + # Parametric bootstrap is not implemented for covariate model with pytest.raises(NotImplementedError) as e_info: - model_1.bootstrap_stats(X, Y, n_repetitions=3) + model_1.bootstrap_stats(X, Y, parametric=parametric, n_repetitions=3) + else: + model_1.bootstrap_stats(X, Y, parametric=parametric, n_repetitions=3) def test_nested_bootstrap(data_nested, kwargs_nested):