Skip to content

Commit

Permalink
Merge pull request #138 from scikit-learn-contrib/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
JulienRoussel77 authored Apr 17, 2024
2 parents 5683e66 + 5559c85 commit 3fba238
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v3.12.0
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install dependencies
Expand Down
6 changes: 6 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
History
=======

0.1.5 (2024-04-17)
------------------

* CICD now relies on Node.js 20
* New tests for comparator.py and data.py

0.1.4 (2024-04-15)
------------------

Expand Down
26 changes: 13 additions & 13 deletions qolmat/imputations/em_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _conjugate_gradient(A: NDArray, X: NDArray, mask: NDArray) -> NDArray:
return X_final


def min_diff_Linf(list_params: List[NDArray], n_steps: int, order: int = 1) -> float:
def max_diff_Linf(list_params: List[NDArray], n_steps: int, order: int = 1) -> float:
"""Computes the maximal L infinity norm between the `n_steps` last elements spaced by order.
Used to compute the stop criterion.
Expand Down Expand Up @@ -762,8 +762,8 @@ def _check_convergence(self) -> bool:
if n_iter < 3:
return False

min_diff_means1 = min_diff_Linf(list_covs, n_steps=1)
min_diff_covs1 = min_diff_Linf(list_means, n_steps=1)
min_diff_means1 = max_diff_Linf(list_means, n_steps=1)
min_diff_covs1 = max_diff_Linf(list_covs, n_steps=1)
min_diff_reached = min_diff_means1 < self.tolerance and min_diff_covs1 < self.tolerance

if min_diff_reached:
Expand All @@ -772,16 +772,16 @@ def _check_convergence(self) -> bool:
if n_iter < 7:
return False

min_diff_means5 = min_diff_Linf(list_covs, n_steps=5)
min_diff_covs5 = min_diff_Linf(list_means, n_steps=5)
min_diff_means5 = max_diff_Linf(list_means, n_steps=5)
min_diff_covs5 = max_diff_Linf(list_covs, n_steps=5)

min_diff_stable = (
min_diff_means5 < self.stagnation_threshold
and min_diff_covs5 < self.stagnation_threshold
)

min_diff_loglik5_ord1 = min_diff_Linf(list_logliks, n_steps=5)
min_diff_loglik5_ord2 = min_diff_Linf(list_logliks, n_steps=5, order=2)
min_diff_loglik5_ord1 = max_diff_Linf(list_logliks, n_steps=5)
min_diff_loglik5_ord2 = max_diff_Linf(list_logliks, n_steps=5, order=2)
max_loglik = (min_diff_loglik5_ord1 < self.stagnation_loglik) or (
min_diff_loglik5_ord2 < self.stagnation_loglik
)
Expand Down Expand Up @@ -1105,8 +1105,8 @@ def _check_convergence(self) -> bool:
if n_iter < 3:
return False

min_diff_B1 = min_diff_Linf(list_B, n_steps=1)
min_diff_S1 = min_diff_Linf(list_S, n_steps=1)
min_diff_B1 = max_diff_Linf(list_B, n_steps=1)
min_diff_S1 = max_diff_Linf(list_S, n_steps=1)
min_diff_reached = min_diff_B1 < self.tolerance and min_diff_S1 < self.tolerance

if min_diff_reached:
Expand All @@ -1115,14 +1115,14 @@ def _check_convergence(self) -> bool:
if n_iter < 7:
return False

min_diff_B5 = min_diff_Linf(list_B, n_steps=5)
min_diff_S5 = min_diff_Linf(list_S, n_steps=5)
min_diff_B5 = max_diff_Linf(list_B, n_steps=5)
min_diff_S5 = max_diff_Linf(list_S, n_steps=5)
min_diff_stable = (
min_diff_B5 < self.stagnation_threshold and min_diff_S5 < self.stagnation_threshold
)

max_loglik5_ord1 = min_diff_Linf(list_logliks, n_steps=5, order=1)
max_loglik5_ord2 = min_diff_Linf(list_logliks, n_steps=5, order=2)
max_loglik5_ord1 = max_diff_Linf(list_logliks, n_steps=5, order=1)
max_loglik5_ord2 = max_diff_Linf(list_logliks, n_steps=5, order=2)
max_loglik = (max_loglik5_ord1 < self.stagnation_loglik) or (
max_loglik5_ord2 < self.stagnation_loglik
)
Expand Down

0 comments on commit 3fba238

Please sign in to comment.