Skip to content

Commit

Permalink
Expose parametric in bootstrap_stats and improve relevant test
Browse files Browse the repository at this point in the history
  • Loading branch information
sachaMorin committed Feb 15, 2024
1 parent e6cdbb4 commit 01e7cff
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
12 changes: 10 additions & 2 deletions stepmix/stepmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,7 +1150,13 @@ def bootstrap(
)

def bootstrap_stats(
self, X, Y=None, n_repetitions=1000, sample_weight=None, progress_bar=True
self,
X,
Y=None,
n_repetitions=1000,
sample_weight=None,
parametric=False,
progress_bar=True,
):
"""Non-parametric boostrap of StepMix estimator. Obtain boostrapped parameters and some statistics
(mean and standard deviation).
Expand All @@ -1164,6 +1170,8 @@ def bootstrap_stats(
sample_weight : array-like of shape(n_samples,), default=None
n_repetitions: int
Number of repetitions to fit.
parametric: bool, default=False
Use parametric bootstrap instead of non-parametric. Data will be generated by sampling the estimator.
progress_bar : bool, default=True
Display a tqdm progress bar for repetitions.
Returns
Expand All @@ -1185,7 +1193,7 @@ def bootstrap_stats(
Y=Y,
n_repetitions=n_repetitions,
sample_weight=sample_weight,
parametric=False,
parametric=parametric,
progress_bar=progress_bar,
)

Expand Down
13 changes: 7 additions & 6 deletions test/test_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def test_nested_permutation(data_nested, kwargs_nested):
"ignore::sklearn.exceptions.ConvergenceWarning"
) # Ignore convergence warnings for same reason
@pytest.mark.parametrize("model", EMISSION_DICT.keys())
def test_bootstrap(data, kwargs, model):
@pytest.mark.parametrize("parametric", [False, True])
def test_bootstrap(data, kwargs, model, parametric):
"""Call the boostrap procedure on all models and make sure they don't raise errors.
The data may not make sense for the model. We therefore do not test a particular output here."""
Expand All @@ -100,12 +101,12 @@ def test_bootstrap(data, kwargs, model):
model_1 = StepMix(n_steps=1, **kwargs)
model_1.fit(X, Y)

if model != "covariate":
model_1.bootstrap_stats(X, Y, n_repetitions=3)
else:
# Should raise error. Can't sample from a covariate model
if parametric and model == "covariate":
# Parametric bootstrap is not implemented for covariate model
with pytest.raises(NotImplementedError) as e_info:
model_1.bootstrap_stats(X, Y, n_repetitions=3)
model_1.bootstrap_stats(X, Y, parametric=parametric, n_repetitions=3)
else:
model_1.bootstrap_stats(X, Y, parametric=parametric, n_repetitions=3)


def test_nested_bootstrap(data_nested, kwargs_nested):
Expand Down

0 comments on commit 01e7cff

Please sign in to comment.