From eccb46d2e25dd2275fc54982854d2c59b1ec51b3 Mon Sep 17 00:00:00 2001 From: Julien Roussel <3178729-JulienRoussel77@users.noreply.gitlab.com> Date: Thu, 7 Mar 2024 16:46:00 +0100 Subject: [PATCH] history updated --- HISTORY.rst | 11 + docs/imputers.rst | 2 +- examples/benchmark.md | 153 +++++++------ qolmat/benchmark/metrics.py | 2 +- qolmat/imputations/em_sampler.py | 298 ++++++++++++++++++++----- qolmat/imputations/imputers.py | 61 +++-- qolmat/imputations/imputers_pytorch.py | 3 +- qolmat/imputations/rpca/rpca_noisy.py | 221 ++++++++++-------- qolmat/imputations/rpca/rpca_pcp.py | 8 +- qolmat/imputations/rpca/rpca_utils.py | 3 +- qolmat/utils/data.py | 20 +- qolmat/utils/exceptions.py | 9 + qolmat/utils/plot.py | 72 ++++-- qolmat/utils/utils.py | 15 ++ tests/imputations/test_em_sampler.py | 117 +++++++--- tests/imputations/test_imputers.py | 47 +--- tests/utils/test_data.py | 12 +- tests/utils/test_plot.py | 4 +- 18 files changed, 726 insertions(+), 332 deletions(-) diff --git a/HISTORY.rst b/HISTORY.rst index b5ca356f..64b4fbed 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -2,6 +2,17 @@ History ======= +0.1.3 (2024-03-07) +------------------ + +* RPCA algorithms now start with a normalizing scaler +* The EM algorithms now include a gradient projection step to be more robust to colinearity +* The EM algorithm based on the Gaussian model is now initialized using a robust estimation of the covariance matrix +* A bug in the EM algorithm has been patched: the normalizing matrix gamma was creating a sampling biais +* Speed up of the EM algorithm likelihood maximization, using the conjugate gradient method +* The ImputeRegressor class now handles the nans by `row` by default +* The metric `frechet` was not correctly called and has been patched + 0.1.2 (2024-02-28) ------------------ diff --git a/docs/imputers.rst b/docs/imputers.rst index b66898a8..ad95b6b9 100644 --- a/docs/imputers.rst +++ b/docs/imputers.rst @@ -41,7 +41,7 @@ See the :class:`~qolmat.imputations.imputers.ImputerRpcaPcp` class for implement The class :class:`RPCANoisy` implements an recommanded improved version, which relies on a decomposition :math:`\mathbf{D} = \mathbf{M} + \mathbf{A} + \mathbf{E}`. The additionnal term encodes a Gaussian noise and makes the numerical convergence more reliable. This class also implements a time-consistency penalization for time series, parametrized by the :math:`\eta_k`and :math:`H_k`. By defining :math:`\Vert \mathbf{MH_k} \Vert_p` is either :math:`\Vert \mathbf{MH_k} \Vert_1` or :math:`\Vert \mathbf{MH_k} \Vert_F^2`, the optimisation problem is the following .. math:: - \text{min}_{\mathbf{M, A} \in \mathbb{R}^{m \times n}} \quad \Vert P_{\Omega} (\mathbf{D}-\mathbf{M}-\mathbf{A}) \Vert_F^2 + \tau \Vert \mathbf{M} \Vert_* + \lambda \Vert \mathbf{A} \Vert_1 + \sum_{k=1}^K \eta_k \Vert \mathbf{M H_k} \Vert_p + \text{min}_{\mathbf{M, A} \in \mathbb{R}^{m \times n}} \quad \frac 1 2 \Vert P_{\Omega} (\mathbf{D}-\mathbf{M}-\mathbf{A}) \Vert_F^2 + \tau \Vert \mathbf{M} \Vert_* + \lambda \Vert \mathbf{A} \Vert_1 + \sum_{k=1}^K \eta_k \Vert \mathbf{M H_k} \Vert_p with :math:`\mathbf{E} = \mathbf{D} - \mathbf{M} - \mathbf{A}`. See the :class:`~qolmat.imputations.imputers.ImputerRpcaNoisy` class for implementation details. 
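A minimal usage sketch for this formulation (illustrative only, not taken from the patched docs): `tau` and `lam` map to the nuclear-norm and L1 penalties of the objective above; the toy dataframe and the parameter values are assumptions.

```python
import numpy as np
import pandas as pd
from qolmat.imputations import imputers

# Toy dataframe with missing entries (values are illustrative assumptions)
df = pd.DataFrame({
    "TEMP": [1.0, np.nan, 3.0, 4.0, 5.0],
    "PRES": [10.0, 11.0, np.nan, 13.0, 14.0],
})

# tau penalizes the nuclear norm of M, lam the L1 norm of A; per this release the
# columns are rescaled by a normalizing scaler before the decomposition
imputer = imputers.ImputerRpcaNoisy(columnwise=False, max_iterations=500, tau=0.01, lam=5, rank=1)
df_imputed = imputer.fit_transform(df)
```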
diff --git a/examples/benchmark.md b/examples/benchmark.md index af92b16c..a4f16135 100644 --- a/examples/benchmark.md +++ b/examples/benchmark.md @@ -8,9 +8,9 @@ jupyter: format_version: '1.3' jupytext_version: 1.14.4 kernelspec: - display_name: env_qolmat + display_name: env_qolmat_dev language: python - name: env_qolmat + name: env_qolmat_dev --- **This notebook aims to present the Qolmat repo through an example of a multivariate time series. @@ -28,6 +28,8 @@ import warnings %reload_ext autoreload %autoreload 2 +from IPython.display import Image + import pandas as pd from datetime import datetime import numpy as np @@ -82,12 +84,12 @@ n_cols = len(cols_to_impute) ``` ```python -fig = plt.figure(figsize=(10 * n_stations, 3 * n_cols)) +fig = plt.figure(figsize=(20 * n_stations, 6 * n_cols)) for i_station, (station, df) in enumerate(df_data.groupby("station")): df_station = df_data.loc[station] for i_col, col in enumerate(cols_to_impute): fig.add_subplot(n_cols, n_stations, i_col * n_stations + i_station + 1) - plt.plot(df_station[col], '.', label=station) + plt.plot(df_station[col], label=station) # break plt.ylabel(col) plt.xticks(rotation=15) @@ -127,7 +129,7 @@ imputer_spline = imputers.ImputerInterpolation(groups=("station",), method="spli imputer_shuffle = imputers.ImputerShuffle(groups=("station",)) imputer_residuals = imputers.ImputerResiduals(groups=("station",), period=365, model_tsa="additive", extrapolate_trend="freq", method_interpolation="linear") -imputer_rpca = imputers.ImputerRpcaNoisy(groups=("station",), columnwise=False, max_iterations=500, tau=2, lam=0.05) +imputer_rpca = imputers.ImputerRpcaNoisy(groups=("station",), columnwise=False, max_iterations=500, tau=.01, lam=5, rank=1) imputer_rpca_opti = imputers.ImputerRpcaNoisy(groups=("station",), columnwise=False, max_iterations=256) dict_config_opti["RPCA_opti"] = { "tau": ho.hp.uniform("tau", low=.5, high=5), @@ -141,9 +143,9 @@ dict_config_opti["RPCA_opticw"] = { "lam/PRES": ho.hp.uniform("lam/PRES", low=.1, high=1), } -imputer_ou = imputers.ImputerEM(groups=("station",), model="multinormal", method="sample", max_iter_em=34, n_iter_ou=15, dt=1e-3) -imputer_tsou = imputers.ImputerEM(groups=("station",), model="VAR", method="sample", max_iter_em=34, n_iter_ou=15, dt=1e-3, p=1) -imputer_tsmle = imputers.ImputerEM(groups=("station",), model="VAR", method="mle", max_iter_em=100, n_iter_ou=15, dt=1e-3, p=1) +imputer_normal_sample = imputers.ImputerEM(groups=("station",), model="multinormal", method="sample", max_iter_em=8, n_iter_ou=128, dt=4e-2) +imputer_var_sample = imputers.ImputerEM(groups=("station",), model="VAR", method="sample", max_iter_em=8, n_iter_ou=128, dt=4e-2, p=1) +imputer_var_max = imputers.ImputerEM(groups=("station",), model="VAR", method="mle", max_iter_em=8, n_iter_ou=128, dt=4e-2, p=1) imputer_knn = imputers.ImputerKNN(groups=("station",), n_neighbors=10) imputer_mice = imputers.ImputerMICE(groups=("station",), estimator=LinearRegression(), sample_posterior=False, max_iter=100) @@ -163,17 +165,17 @@ dict_imputers = { # "spline": imputer_spline, # "shuffle": imputer_shuffle, "residuals": imputer_residuals, - # "OU": imputer_ou, - "TSOU": imputer_tsou, - "TSMLE": imputer_tsmle, + "Normal_sample": imputer_normal_sample, + "VAR_sample": imputer_var_sample, + "VAR_max": imputer_var_max, "RPCA": imputer_rpca, # "RPCA_opti": imputer_rpca, # "RPCA_opticw": imputer_rpca_opti2, # "locf": imputer_locf, # "nocb": imputer_nocb, # "knn": imputer_knn, - "ols": imputer_regressor, - "mice_ols": imputer_mice, + "OLS": 
imputer_regressor, + "MICE_OLS": imputer_mice, } n_imputers = len(dict_imputers) ``` @@ -181,7 +183,7 @@ n_imputers = len(dict_imputers) In order to compare the methods, we $i)$ artificially create missing data (for missing data mechanisms, see the docs); $ii)$ then impute it using the different methods chosen and $iii)$ calculate the reconstruction error. These three steps are repeated a number of times equal to `n_splits`. For each method, we calculate the average error and compare the final errors.
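The sketch below wires these three steps together using the hole generator and comparator objects that appear later in this notebook (import path assumed); the metric list, number of splits and masking ratio are illustrative assumptions.

```python
from qolmat.benchmark import comparator, missing_patterns

# i) generator that artificially masks entries, repeated over n_splits folds
generator_holes = missing_patterns.EmpiricalHoleGenerator(
    n_splits=4, groups=("station",), subset=cols_to_impute, ratio_masked=0.1
)
# ii) + iii) impute each fold with every method and score the reconstructions
comparison = comparator.Comparator(
    dict_imputers,
    cols_to_impute,
    generator_holes=generator_holes,
    metrics=["mae", "wmape", "KL_columnwise"],
    max_evals=2,
)
results = comparison.compare(df_data)  # scores indexed by (metric, column), one column per imputer
```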


@@ -190,14 +192,14 @@ Concretely, the comparator takes as input a dataframe to impute, a proportion of Note these metrics compute reconstruction errors; it tells nothing about the distances between the "true" and "imputed" distributions. -```python -metrics = ["mae", "wmape", "KL_columnwise", "ks_test", "dist_corr_pattern"] +```python tags=[] +metrics = ["mae", "wmape", "KL_columnwise", "frechet"] comparison = comparator.Comparator( dict_imputers, cols_to_impute, generator_holes = generator_holes, metrics=metrics, - max_evals=10, + max_evals=2, dict_config_opti=dict_config_opti, ) results = comparison.compare(df_data) @@ -220,9 +222,9 @@ plt.show() ### **III. Comparison of methods** -We now run just one time each algorithm on the initial corrupted dataframe and compare the different performances through multiple analysis. +We now run just one time each algorithm on the initial corrupted dataframe and visualize the different imputations. -```python +```python tags=[] df_plot = df_data[cols_to_impute] ``` @@ -233,12 +235,17 @@ dfs_imputed = {name: imp.fit_transform(df_plot) for name, imp in dict_imputers.i ```python station = df_plot.index.get_level_values("station")[0] df_station = df_plot.loc[station] +# dfs_imputed_station = {name: df_plot.loc[station] for name, df_plot in dfs_imputed.items()} dfs_imputed_station = {name: df_plot.loc[station] for name, df_plot in dfs_imputed.items()} ``` Let's look at the imputations. When the data is missing at random, imputation is easier. Missing block are more challenging. +```python +dfs_imputed_station["VAR_max"] +``` + ```python for col in cols_to_impute: fig, ax = plt.subplots(figsize=(10, 3)) @@ -270,21 +277,21 @@ i_plot = 1 for i_col, col in enumerate(cols_to_impute): for name_imputer, df_imp in dfs_imputed_station.items(): - fig.add_subplot(n_columns, n_imputers, i_plot) + ax = fig.add_subplot(n_columns, n_imputers, i_plot) values_orig = df_station[col] values_imp = df_imp[col].copy() values_imp[values_orig.notna()] = np.nan - plt.plot(values_imp, marker="o", color=tab10(0), label=name_imputer, alpha=1) + plt.plot(values_imp, marker="o", color=tab10(0), label="imputation", alpha=1) plt.plot(values_orig, color='black', marker="o", label="original") plt.ylabel(col, fontsize=16) - if i_plot % n_columns == 1: - plt.legend(loc=[1, 0], fontsize=18) + if i_plot % n_imputers == 0: + plt.legend(loc="lower right", fontsize=18) plt.xticks(rotation=15) if i_col == 0: plt.title(name_imputer) if i_col != n_columns - 1: - plt.xticks([], []) + ax.set_xticklabels([]) loc = plticker.MultipleLocator(base=2*365) ax.xaxis.set_major_locator(loc) ax.tick_params(axis='both', which='major') @@ -297,7 +304,7 @@ plt.show() ## (Optional) Deep Learning Model -In this section, we present an MLP model of data imputation using Keras, which can be installed using a "pip install pytorch". +In this section, we present an MLP model of data imputation using PyTorch, which can be installed using a "pip install qolmat[pytorch]". ```python from qolmat.imputations import imputers_pytorch @@ -308,17 +315,6 @@ except ModuleNotFoundError: raise PyTorchExtraNotInstalled ``` -For the MLP model, we work on a dataset that corresponds to weather data with missing values. We add missing MCAR values on the features "TEMP", "PRES" and other features with NaN values. The goal is impute the missing values for the features "TEMP" and "PRES" by a Deep Learning method. 
We add features to take into account the seasonality of the data set and a feature for the station name - -```python -df = data.get_data("Beijing") -cols_to_impute = ["TEMP", "PRES"] -cols_with_nans = list(df.columns[df.isna().any()]) -df_data = data.add_datetime_features(df) -df_data[cols_with_nans + cols_to_impute] = data.add_holes(pd.DataFrame(df_data[cols_with_nans + cols_to_impute]), ratio_masked=.1, mean_size=120) -df_data -``` - For the example, we use a simple MLP model with 3 layers of neurons. Then we train the model without taking a group on the stations @@ -340,49 +336,75 @@ plt.show() ``` ```python -# estimator = nn.Sequential( -# nn.Linear(np.sum(df_data.isna().sum()==0), 256), -# nn.ReLU(), -# nn.Linear(256, 128), -# nn.ReLU(), -# nn.Linear(128, 64), -# nn.ReLU(), -# nn.Linear(64, 1) -# ) -estimator = imputers_pytorch.build_mlp(input_dim=np.sum(df_data.isna().sum()==0), list_num_neurons=[256,128,64]) -encoder, decoder = imputers_pytorch.build_autoencoder(input_dim=df_data.values.shape[1],latent_dim=4, output_dim=df_data.values.shape[1], list_num_neurons=[4*4, 2*4]) +n_variables = len(cols_to_impute) + +estimator = imputers_pytorch.build_mlp(input_dim=n_variables-1, list_num_neurons=[256,128,64]) +encoder, decoder = imputers_pytorch.build_autoencoder(input_dim=n_variables,latent_dim=4, output_dim=n_variables, list_num_neurons=[4*4, 2*4]) ``` ```python -dict_imputers["MLP"] = imputer_mlp = imputers_pytorch.ImputerRegressorPyTorch(estimator=estimator, groups=('station',), handler_nan = "column", epochs=500) +dict_imputers["MLP"] = imputer_mlp = imputers_pytorch.ImputerRegressorPyTorch(estimator=estimator, groups=('station',), epochs=500) dict_imputers["Autoencoder"] = imputer_autoencoder = imputers_pytorch.ImputerAutoencoder(encoder, decoder, max_iterations=100, epochs=100) dict_imputers["Diffusion"] = imputer_diffusion = imputers_pytorch.ImputerDiffusion(model=TabDDPM(num_sampling=5), epochs=100, batch_size=100) ``` We can re-run the imputation model benchmark as before. +```python +comparison = comparator.Comparator( + dict_imputers, + cols_to_impute, + generator_holes = generator_holes, + metrics=metrics, + max_evals=2, + dict_config_opti=dict_config_opti, +) +``` + ```python tags=[] generator_holes = missing_patterns.EmpiricalHoleGenerator(n_splits=3, groups=('station',), subset=cols_to_impute, ratio_masked=ratio_masked) comparison = comparator.Comparator( dict_imputers, - selected_columns = df_data.columns, + cols_to_impute, generator_holes = generator_holes, - metrics=["mae", "wmape", "KL_columnwise", "ks_test"], - max_evals=10, + metrics=metrics, + max_evals=2, dict_config_opti=dict_config_opti, ) results = comparison.compare(df_data) results.style.highlight_min(color="green", axis=1) ``` +```python +n_metrics = len(metrics) +fig = plt.figure(figsize=(24, 4 * n_metrics)) +for i, metric in enumerate(metrics): + fig.add_subplot(n_metrics, 1, i + 1) + df = results.loc[metric] + plot.multibar(df, decimals=2) + plt.ylabel(metric) + +#plt.savefig("figures/imputations_benchmark_errors.png") +plt.show() +``` + ```python tags=[] -df_plot = df_data +df_plot = df_data[cols_to_impute] +``` + +```python dfs_imputed = {name: imp.fit_transform(df_plot) for name, imp in dict_imputers.items()} +``` + +```python station = df_plot.index.get_level_values("station")[0] df_station = df_plot.loc[station] dfs_imputed_station = {name: df_plot.loc[station] for name, df_plot in dfs_imputed.items()} ``` -```python tags=[] +Let's look at the imputations. 
+When the data is missing at random, imputation is easier. Missing block are more challenging. + +```python for col in cols_to_impute: fig, ax = plt.subplots(figsize=(10, 3)) values_orig = df_station[col] @@ -399,39 +421,42 @@ for col in cols_to_impute: ax.xaxis.set_major_locator(loc) ax.tick_params(axis='both', which='major', labelsize=17) plt.show() + ``` ```python -n_columns = len(df_plot.columns) +# plot.plot_imputations(df_station, dfs_imputed_station) + +n_columns = len(cols_to_impute) n_imputers = len(dict_imputers) -fig = plt.figure(figsize=(8 * n_imputers, 6 * n_columns)) +fig = plt.figure(figsize=(12 * n_imputers, 4 * n_columns)) i_plot = 1 -for i_col, col in enumerate(df_plot): +for i_col, col in enumerate(cols_to_impute): for name_imputer, df_imp in dfs_imputed_station.items(): - fig.add_subplot(n_columns, n_imputers, i_plot) + ax = fig.add_subplot(n_columns, n_imputers, i_plot) values_orig = df_station[col] - plt.plot(values_orig, ".", color='black', label="original") - values_imp = df_imp[col].copy() values_imp[values_orig.notna()] = np.nan - plt.plot(values_imp, ".", color=tab10(0), label=name_imputer, alpha=1) + plt.plot(values_imp, marker="o", color=tab10(0), label="imputation", alpha=1) + plt.plot(values_orig, color='black', marker="o", label="original") plt.ylabel(col, fontsize=16) - if i_plot % n_columns == 1: - plt.legend(loc=[1, 0], fontsize=18) + if i_plot % n_imputers == 0: + plt.legend(loc="lower right", fontsize=18) plt.xticks(rotation=15) if i_col == 0: plt.title(name_imputer) if i_col != n_columns - 1: - plt.xticks([], []) + ax.set_xticklabels([]) loc = plticker.MultipleLocator(base=2*365) ax.xaxis.set_major_locator(loc) ax.tick_params(axis='both', which='major') i_plot += 1 -plt.savefig("figures/imputations_benchmark.png") + plt.show() + ``` ## Covariance diff --git a/qolmat/benchmark/metrics.py b/qolmat/benchmark/metrics.py index 802315b6..43f76b68 100644 --- a/qolmat/benchmark/metrics.py +++ b/qolmat/benchmark/metrics.py @@ -1011,7 +1011,7 @@ def get_metric(name: str) -> Callable: "ks_test": kolmogorov_smirnov_test, "correlation_diff": mean_difference_correlation_matrix_numerical_features, "energy": sum_energy_distances, - "frechet": frechet_distance, + "frechet": frechet_distance_pattern, "dist_corr_pattern": partial( pattern_based_weighted_mean_metric, metric=distance_anticorr, diff --git a/qolmat/imputations/em_sampler.py b/qolmat/imputations/em_sampler.py index 707991a6..93f577f1 100644 --- a/qolmat/imputations/em_sampler.py +++ b/qolmat/imputations/em_sampler.py @@ -1,5 +1,6 @@ from abc import abstractmethod from typing import Dict, List, Literal, Union +import warnings import numpy as np from numpy.typing import NDArray @@ -11,6 +12,10 @@ from qolmat.utils import utils +from matplotlib import pyplot as plt + +from qolmat.utils.exceptions import IllConditioned + def _conjugate_gradient(A: NDArray, X: NDArray, mask: NDArray) -> NDArray: """ @@ -20,7 +25,7 @@ def _conjugate_gradient(A: NDArray, X: NDArray, mask: NDArray) -> NDArray: Parameters ---------- A : NDArray - Symmetrical matrix defining the quadratic optimization problem + Symmetrical matrix defining the quadratic minimization problem X : NDArray Array containing the values to optimize mask : NDArray @@ -35,21 +40,32 @@ def _conjugate_gradient(A: NDArray, X: NDArray, mask: NDArray) -> NDArray: X_temp = X[rows_imputed, :].copy() mask = mask[rows_imputed, :].copy() n_iter = mask.sum(axis=1).max() + n_rows, n_cols = X_temp.shape X_temp[mask] = 0 b = -X_temp @ A b[~mask] = 0 - xn, pn, rn = 
np.zeros(X_temp.shape), b, b # Initialisation + xn, pn, rn = np.zeros((n_rows, n_cols)), b, b # Initialisation + alphan = np.zeros(n_rows) + betan = np.zeros(n_rows) for n in range(n_iter + 2): # if np.max(np.sum(rn**2)) < tolerance : # Condition de sortie " usuelle " # X_temp[mask_isna] = xn[mask_isna] # return X_temp.transpose() Apn = pn @ A Apn[~mask] = 0 - alphan = np.sum(rn**2, axis=1) / np.sum(pn * Apn, axis=1) - alphan[np.isnan(alphan)] = 0 # we stop updating if convergence is reached for this date + numerator = np.sum(rn**2, axis=1) + denominator = np.sum(pn * Apn, axis=1) + not_converged = denominator != 0 + # we stop updating if convergence is reached for this row + alphan[not_converged] = numerator[not_converged] / denominator[not_converged] + xn, rnp1 = xn + pn * alphan[:, None], rn - Apn * alphan[:, None] - betan = np.sum(rnp1**2, axis=1) / np.sum(rn**2, axis=1) - betan[np.isnan(betan)] = 0 # we stop updating if convergence is reached for this date + numerator = np.sum(rnp1**2, axis=1) + denominator = np.sum(rn**2, axis=1) + not_converged = denominator != 0 + # we stop updating if convergence is reached for this row + betan[not_converged] = numerator[not_converged] / denominator[not_converged] + pn, rn = rnp1 + pn * betan[:, None], rnp1 X_temp[mask] = xn[mask] @@ -116,6 +132,8 @@ class EM(BaseEstimator, TransformerMixin): stagnation_loglik : float, optional Threshold below which an absolute difference of the log likelihood indicates the convergence of the parameters + min_std: float, optional + Threshold below which the initial data matrix is considered ill-conditioned period : int, optional Integer used to fold the temporal data periodically verbose : bool, optional @@ -134,6 +152,7 @@ def __init__( tolerance: float = 1e-4, stagnation_threshold: float = 5e-3, stagnation_loglik: float = 2, + min_std: float = 1e-6, period: int = 1, verbose: bool = False, ): @@ -151,10 +170,14 @@ def __init__( self.stagnation_threshold = stagnation_threshold self.stagnation_loglik = stagnation_loglik + self.min_std = min_std + self.dict_criteria_stop: Dict[str, List] = {} self.period = period self.verbose = verbose self.n_samples = n_samples + self.hash_fit = 0 + self.shape = (0, 0) def _check_convergence(self) -> bool: return False @@ -176,6 +199,18 @@ def fit_parameters(self, X: NDArray): self.update_parameters(X) self.combine_parameters() + def fit_parameters_with_missingness(self, X: NDArray): + """ + First estimation of the model parameters based on data with missing values. + + Parameters + ---------- + X : NDArray + Data matrix with missingness + """ + X_imp = self.init_imputation(X) + self.fit_parameters(X_imp) + def update_criteria_stop(self, X: NDArray): self.loglik = self.get_loglikelihood(X) @@ -190,9 +225,22 @@ def gradient_X_loglik( ) -> NDArray: return np.empty # type: ignore #noqa - def get_gamma(self) -> NDArray: - n_rows, n_cols = self.shape_original - return np.ones((1, n_cols)) + def get_gamma(self, n_cols: int) -> NDArray: + """ + Normalization matrix in the sampling process. + + Parameters + ---------- + n_cols : int + Number of variables in the data matrix + + Returns + ------- + NDArray + Gamma matrix + """ + # return np.ones((1, n_cols)) + return np.eye(n_cols) def _maximize_likelihood(self, X: NDArray, mask_na: NDArray) -> NDArray: """Get the argmax of a posterior distribution using the BFGS algorithm. @@ -200,7 +248,7 @@ def _maximize_likelihood(self, X: NDArray, mask_na: NDArray) -> NDArray: Parameters ---------- X : NDArray - Input numpy array. 
+ Input numpy array without missingness mask_na : NDArray Boolean dataframe indicating which coefficients should be resampled, and are therefore the variables of the optimization @@ -214,22 +262,19 @@ def _maximize_likelihood(self, X: NDArray, mask_na: NDArray) -> NDArray: def fun_obj(x): x_mat = X.copy() x_mat[mask_na] = x - return self.get_loglikelihood(x_mat) + return -self.get_loglikelihood(x_mat) def fun_jac(x): x_mat = X.copy() x_mat[mask_na] = x - grad_x = self.gradient_X_loglik(x_mat) - grad_x[~mask_na] = 0 + grad_x = -self.gradient_X_loglik(x_mat) + grad_x = grad_x[mask_na] return grad_x - res = spo.minimize(fun_obj, X[mask_na], jac=fun_jac) - - # for _ in range(1000): - # grad = self.gradient_X_loglik(X) - # grad[~mask_na] = 0 - # X += dt * grad + # the method BFGS is much slower, probabily not adapted to the high-dimension setting + res = spo.minimize(fun_obj, X[mask_na], jac=fun_jac, method="CG") x = res.x + X_sol = X.copy() X_sol[mask_na] = x return X_sol @@ -263,16 +308,17 @@ def _sample_ou( Sampled data matrix """ X_copy = X.copy() - n_variables, n_samples = X_copy.shape + n_rows, n_cols = X_copy.shape if estimate_params: self.reset_learned_parameters() X_init = X.copy() - gamma = self.get_gamma() + gamma = self.get_gamma(n_cols) sqrt_gamma = np.real(spl.sqrtm(gamma)) + for i in range(self.n_iter_ou): - noise = self.ampli * self.rng.normal(0, 1, size=(n_variables, n_samples)) - grad_X = self.gradient_X_loglik(X_copy) - X_copy += self.dt * grad_X @ gamma + np.sqrt(2 * self.dt) * noise @ sqrt_gamma + noise = self.ampli * self.rng.normal(0, 1, size=(n_rows, n_cols)) + grad_X = -self.gradient_X_loglik(X_copy) + X_copy += -self.dt * grad_X @ gamma + np.sqrt(2 * self.dt) * noise @ sqrt_gamma X_copy[~mask_na] = X_init[~mask_na] if estimate_params: self.update_parameters(X_copy) @@ -283,20 +329,27 @@ def fit_X(self, X: NDArray) -> None: mask_na = np.isnan(X) # first imputation - X = utils.linear_interpolation(X) - self.fit_parameters(X) + X_imp = self.init_imputation(X) + self._check_conditionning(X_imp) + + self.fit_parameters_with_missingness(X) if not np.any(mask_na): self.X = X + return + + X = self._maximize_likelihood(X_imp, mask_na) for iter_em in range(self.max_iter_em): X = self._sample_ou(X, mask_na) + self.combine_parameters() # Stop criteria self.update_criteria_stop(X) if self._check_convergence(): - print(f"EM converged after {iter_em} iterations.") + if self.verbose: + print(f"EM converged after {iter_em} iterations.") break self.dict_criteria_stop = {key: [] for key in self.dict_criteria_stop} @@ -359,23 +412,58 @@ def transform(self, X: NDArray) -> NDArray: Final array after EM sampling. """ mask_na = np.isnan(X) + X = X.copy() # shape_original = X.shape if hash(X.tobytes()) == self.hash_fit: X = self.X + warm_start = True else: X = utils.prepare_data(X, self.period) - X = utils.linear_interpolation(X) + X = self.init_imputation(X) + warm_start = False - if self.method == "mle": - X_transformed = self._maximize_likelihood(X, mask_na) - elif self.method == "sample": - X_transformed = self._sample_ou(X, mask_na, estimate_params=False) + if (self.method == "mle") or not warm_start: + X = self._maximize_likelihood(X, mask_na) + if self.method == "sample": + X = self._sample_ou(X, mask_na, estimate_params=False) - if np.all(np.isnan(X_transformed)): + if np.all(np.isnan(X)): raise AssertionError("Result contains NaN. 
This is a bug.") - return X_transformed + return X + + def _check_conditionning(self, X: NDArray): + """ + Check that the data matrix X is not ill-conditioned. Running the EM algorithm on data with + colinear columns leads to numerical instability and unconsistent results. + + Parameters + ---------- + X : NDArray + Data matrix + + Raises + ------ + IllConditioned + Data matrix is ill-conditioned due to colinear columns. + """ + n_rows, n_cols = X.shape + # if n_rows == 1 the function np.cov returns a float + if n_rows == 1: + min_sv = 0 + else: + cov = np.cov(X, bias=True, rowvar=False).reshape(n_cols, -1) + _, sv, _ = spl.svd(cov) + min_sv = min(np.sqrt(sv)) + if min_sv < self.min_std: + warnings.warn( + f"The covariance matrix is ill-conditioned, indicating high-colinearity: the " + f"smallest singular value of the data matrix is smaller than the threshold " + f"min_std ({min_sv} < {self.min_std}). Consider removing columns of decreasing " + f"the threshold." + ) + # raise IllConditioned(min_sv, self.min_std) class MultiNormalEM(EM): @@ -392,6 +480,8 @@ class MultiNormalEM(EM): n_iter_ou : int, optional Number of iterations for the Gibbs sampling method (+ noise addition), necessary for convergence, by default 50. + n_samples : int, optional + Number of data samples used to estimate the parameters of the distribution. Default, 10 ampli : float, optional Whether to sample the posterior (1) or to maximise likelihood (0), by default 1. @@ -420,6 +510,7 @@ def __init__( method: Literal["mle", "sample"] = "sample", max_iter_em: int = 200, n_iter_ou: int = 50, + n_samples: int = 10, ampli: float = 1, random_state: Union[None, int, np.random.RandomState] = None, dt: float = 2e-2, @@ -433,6 +524,7 @@ def __init__( method=method, max_iter_em=max_iter_em, n_iter_ou=n_iter_ou, + n_samples=n_samples, ampli=ampli, random_state=random_state, dt=dt, @@ -480,18 +572,28 @@ def gradient_X_loglik(self, X: NDArray) -> NDArray: grad_X = -(X - self.means) @ self.cov_inv return grad_X - def get_gamma(self) -> NDArray: + def get_gamma(self, n_cols: int) -> NDArray: """ - Normalisation matrix used to stabilize the sampling process + If the covariance matrix is not full-rank, defines the projection matrix keeping the + sampling process in the relevant subspace. + + Parameters + ---------- + n_cols : int + Number of variables in the data matrix Returns ------- NDArray Gamma matrix """ - # gamma = np.diag(np.diagonal(self.cov)) - gamma = self.cov + U, diag, Vt = spl.svd(self.cov) + diag_trunc = np.where(diag < self.min_std**2, 0, diag) + diag_trunc = np.where(diag_trunc == 0, 0, np.min(diag_trunc)) + + gamma = (U * diag_trunc) @ Vt # gamma = np.eye(len(self.cov)) + return gamma def update_criteria_stop(self, X: NDArray): @@ -554,6 +656,34 @@ def combine_parameters(self): self.cov = cov_intragroup + cov_intergroup self.cov_inv = np.linalg.pinv(self.cov) + def fit_parameters_with_missingness(self, X: NDArray): + """ + First estimation of the model parameters based on data with missing values. + + Parameters + ---------- + X : NDArray + Data matrix with missingness + """ + self.means = np.nanmean(X, axis=0) + self.cov = utils.nancov(X) + self.cov_inv = np.linalg.pinv(self.cov) + + def set_parameters(self, means: NDArray, cov: NDArray): + """ + Sets the model parameters from a user value. 
+ + Parameters + ---------- + means : NDArray + Specified value for the mean vector + cov : NDArray + Specified value for the covariance matrix + """ + self.means = means + self.cov = cov + self.cov_inv = np.linalg.pinv(self.cov) + def _maximize_likelihood(self, X: NDArray, mask_na: NDArray) -> NDArray: """ Get the argmax of a posterior distribution. @@ -561,7 +691,7 @@ def _maximize_likelihood(self, X: NDArray, mask_na: NDArray) -> NDArray: Parameters ---------- X : NDArray - Input DataFrame. + Input DataFrame without missingness mask_na : NDArray Boolean dataframe indicating which coefficients should be resampled, and are therefore the variables of the optimization @@ -576,6 +706,22 @@ def _maximize_likelihood(self, X: NDArray, mask_na: NDArray) -> NDArray: X_imputed = self.means + X_imputed return X_imputed + def init_imputation(self, X: NDArray) -> NDArray: + """ + First simple imputation before iterating. + + Parameters + ---------- + X : NDArray + Data matrix, with missing values + + Returns + ------- + NDArray + Imputed matrix + """ + return utils.impute_nans(X, method="median") + def _check_convergence(self) -> bool: """ Check if the EM algorithm has converged. Three criteria: @@ -597,13 +743,19 @@ def _check_convergence(self) -> bool: list_logliks = self.dict_criteria_stop["logliks"] n_iter = len(list_means) - if n_iter < 10: + if n_iter < 3: return False min_diff_means1 = min_diff_Linf(list_covs, n_steps=1) min_diff_covs1 = min_diff_Linf(list_means, n_steps=1) min_diff_reached = min_diff_means1 < self.tolerance and min_diff_covs1 < self.tolerance + if min_diff_reached: + return True + + if n_iter < 7: + return False + min_diff_means5 = min_diff_Linf(list_covs, n_steps=5) min_diff_covs5 = min_diff_Linf(list_means, n_steps=5) @@ -617,8 +769,7 @@ def _check_convergence(self) -> bool: max_loglik = (min_diff_loglik5_ord1 < self.stagnation_loglik) or ( min_diff_loglik5_ord2 < self.stagnation_loglik ) - - return min_diff_reached or min_diff_stable or max_loglik + return min_diff_stable or max_loglik class VARpEM(EM): @@ -760,17 +911,28 @@ def gradient_X_loglik(self, X: NDArray) -> NDArray: return grad_1 + grad_2 - def get_gamma(self) -> NDArray: + def get_gamma(self, n_cols: int) -> NDArray: """ - Normalisation matrix used to stabilize the sampling process + If the noise matrix is not full-rank, defines the projection matrix keeping the + sampling process in the relevant subspace. Rescales the process to avoid instabilities. + + Parameters + ---------- + n_cols : int + Number of variables in the data matrix Returns ------- NDArray Gamma matrix """ - # gamma = np.diagonal(self.S).reshape(1, -1) - gamma = self.S + U, diag, Vt = spl.svd(self.S) + diag_trunc = np.where(diag < self.min_std**2, 0, diag) + diag_trunc = np.where(diag_trunc == 0, 0, np.min(diag_trunc)) + + gamma = (U * diag_trunc) @ Vt + # gamma = np.eye(len(self.cov)) + return gamma def update_criteria_stop(self, X: NDArray): @@ -841,9 +1003,40 @@ def combine_parameters(self) -> None: stack_YY = np.stack(list_YY) self.YY = np.mean(stack_YY, axis=0) self.S = self.YY - self.ZY.T @ self.B - self.B.T @ self.ZY + self.B.T @ self.ZZ @ self.B - self.S[self.S < 1e-12] = 0 + self.S[np.abs(self.S) < 1e-12] = 0 self.S_inv = np.linalg.pinv(self.S, rcond=1e-10) + def set_parameters(self, B: NDArray, S: NDArray): + """ + Sets the model parameters from a user value. 
+ + Parameters + ---------- + means : NDArray + Specified value for the autoregression matrix + S : NDArray + Specified value for the noise covariance matrix + """ + self.B = B + self.S = S + self.S_inv = np.linalg.pinv(self.S) + + def init_imputation(self, X: NDArray) -> NDArray: + """ + First simple imputation before iterating. + + Parameters + ---------- + X : NDArray + Data matrix, with missing values + + Returns + ------- + NDArray + Imputed matrix + """ + return utils.linear_interpolation(X) + def _check_convergence(self) -> bool: """ Check if the EM algorithm has converged. Three criteria: @@ -866,13 +1059,19 @@ def _check_convergence(self) -> bool: list_logliks = self.dict_criteria_stop["logliks"] n_iter = len(list_B) - if n_iter < 10: + if n_iter < 3: return False min_diff_B1 = min_diff_Linf(list_B, n_steps=1) min_diff_S1 = min_diff_Linf(list_S, n_steps=1) min_diff_reached = min_diff_B1 < self.tolerance and min_diff_S1 < self.tolerance + if min_diff_reached: + return True + + if n_iter < 7: + return False + min_diff_B5 = min_diff_Linf(list_B, n_steps=5) min_diff_S5 = min_diff_Linf(list_S, n_steps=5) min_diff_stable = ( @@ -884,5 +1083,4 @@ def _check_convergence(self) -> bool: max_loglik = (max_loglik5_ord1 < self.stagnation_loglik) or ( max_loglik5_ord2 < self.stagnation_loglik ) - - return min_diff_reached or min_diff_stable or max_loglik + return min_diff_stable or max_loglik diff --git a/qolmat/imputations/imputers.py b/qolmat/imputations/imputers.py index 25fc5a2c..96cd8778 100644 --- a/qolmat/imputations/imputers.py +++ b/qolmat/imputations/imputers.py @@ -201,12 +201,15 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame: cols_with_nans = df.columns[df.isna().any()] - if self.columnwise: - df_imputed = df.copy() - for col in cols_with_nans: - df_imputed[col] = self._transform_allgroups(df[[col]], col=col) + if cols_with_nans.empty: + df_imputed = df else: - df_imputed = self._transform_allgroups(df) + if self.columnwise: + df_imputed = df.copy() + for col in cols_with_nans: + df_imputed[col] = self._transform_allgroups(df[[col]], col=col) + else: + df_imputed = self._transform_allgroups(df) if df_imputed.isna().any().any(): raise AssertionError("Result of imputation contains NaN!") @@ -1456,6 +1459,7 @@ class ImputerRegressor(_Imputer): - if `row` all non complete rows will be removed from the train dataset, and will not be used for the inferance, - if `column` all non complete columns will be ignored. + By default, `row` random_state : Union[None, int, np.random.RandomState], optional Controls the randomness of the fit_transform, by default None @@ -1484,7 +1488,7 @@ def __init__( imputer_params: Tuple[str, ...] = ("handler_nan",), groups: Tuple[str, ...] 
= (), estimator: Optional[BaseEstimator] = None, - handler_nan: str = "column", + handler_nan: str = "row", random_state: Union[None, int, np.random.RandomState] = None, ): super().__init__( @@ -1547,7 +1551,6 @@ def _fit_element( assert col == "__all__" cols_with_nans = df.columns[df.isna().any()] dict_estimators: Dict[str, BaseEstimator] = dict() - for col in cols_with_nans: # Selects only the valid values in the Train Set according to the chosen method X, y = self.get_Xy_valid(df, col) @@ -1604,6 +1607,8 @@ def _transform_element( # Selects only non-NaN values for the Test Set is_na = y.isna() + if not np.any(is_na): + continue X = X.loc[is_na] y_hat = self._predict_estimator(model, X) @@ -1720,7 +1725,13 @@ def _transform_element( Omega = ~np.isnan(D) # D = utils.linear_interpolation(D) - M, A = model.decompose(D, Omega) + means = np.nanmean(D, axis=0) + stds = np.nanstd(D, axis=0) + stds = np.where(stds, stds, 1) + D_scale = (D - means) / stds + M, A = model.decompose(D_scale, Omega) + M = M * stds + means + A = A * stds + means M_final = utils.get_shape_original(M, X.shape) A_final = utils.get_shape_original(A, X.shape) @@ -1823,7 +1834,9 @@ def get_model(self, **hyperparams) -> rpca_noisy.RpcaNoisy: model = rpca_noisy.RpcaNoisy(random_state=self._rng, verbose=self.verbose, **hyperparams) return model - def _fit_element(self, df: pd.DataFrame, col: str = "__all__", ngroup: int = 0) -> NDArray: + def _fit_element( + self, df: pd.DataFrame, col: str = "__all__", ngroup: int = 0 + ) -> Tuple[NDArray, NDArray, NDArray]: """ Fits the imputer on `df`, at the group and/or column level depending on self.groups and self.columnwise. @@ -1839,8 +1852,11 @@ def _fit_element(self, df: pd.DataFrame, col: str = "__all__", ngroup: int = 0) Returns ------- - NDArray - Returns the reduced decomposition basis + Tuple + A tuple made of: + - the reduced decomposition basis + - the estimated mean of the columns + - the estimated standard deviation of the columns Raises ------ @@ -1855,9 +1871,14 @@ def _fit_element(self, df: pd.DataFrame, col: str = "__all__", ngroup: int = 0) D = utils.prepare_data(X, self.period) Omega = ~np.isnan(D) # D = utils.linear_interpolation(D) - _, _, _, Q = model.decompose_with_basis(X, Omega) - return Q + means = np.nanmean(D, axis=0) + stds = np.nanstd(D, axis=0) + stds = np.where(stds, stds, 1) + D_scale = (D - means) / stds + _, _, _, Q = model.decompose_with_basis(D_scale, Omega) + + return Q, means, stds def _transform_element( self, df: pd.DataFrame, col: str = "__all__", ngroup: int = 0 @@ -1895,14 +1916,16 @@ def _transform_element( Omega = ~np.isnan(D) # D = utils.linear_interpolation(D) - Q = self._dict_fitting[col][ngroup] - M, A = model.decompose_on_basis(D, Omega, Q) + Q, means, stds = self._dict_fitting[col][ngroup] + + D_scale = (D - means) / stds + M, A = model.decompose_on_basis(D_scale, Omega, Q) + M = M * stds + means + A = A * stds + means M_final = utils.get_shape_original(M, X.shape) - A_final = utils.get_shape_original(A, X.shape) - X_imputed = M_final + A_final - df_imputed = pd.DataFrame(X_imputed, index=df.index, columns=df.columns) + df_imputed = pd.DataFrame(M_final, index=df.index, columns=df.columns) df_imputed = df.where(~df.isna(), df_imputed) return df_imputed @@ -2230,6 +2253,8 @@ def _transform_element( """ self._check_dataframe(df) + if df.notna().all().all(): + return df model = self._dict_fitting[col][ngroup] X = df.values.astype(float) diff --git a/qolmat/imputations/imputers_pytorch.py b/qolmat/imputations/imputers_pytorch.py index 
ed6cc198..c2ee8a4a 100644 --- a/qolmat/imputations/imputers_pytorch.py +++ b/qolmat/imputations/imputers_pytorch.py @@ -35,6 +35,7 @@ class ImputerRegressorPyTorch(ImputerRegressor): - if `row` all non complete rows will be removed from the train dataset, and will not be used for the inferance, - if `column`all non complete columns will be ignored. + By default, `row` epochs: int Number of epochs when fitting the autoencoder, by default 100 learning_rate: float @@ -47,7 +48,7 @@ def __init__( self, groups: Tuple[str, ...] = (), estimator: Optional[nn.Sequential] = None, - handler_nan: str = "column", + handler_nan: str = "row", epochs: int = 100, learning_rate: float = 0.001, loss_fn: Callable = nn.L1Loss(), diff --git a/qolmat/imputations/rpca/rpca_noisy.py b/qolmat/imputations/rpca/rpca_noisy.py index d7a1a06d..8f836d99 100644 --- a/qolmat/imputations/rpca/rpca_noisy.py +++ b/qolmat/imputations/rpca/rpca_noisy.py @@ -28,32 +28,6 @@ class RpcaNoisy(RPCA): Chen, Yuxin, et al. "Bridging convex and nonconvex optimization in robust PCA: Noise, outliers and missing data." The Annals of Statistics 49.5 (2021): 2948-2971. - - Parameters - ---------- - random_state : int, optional - The seed of the pseudo random number generator to use, for reproductibility. - rank: Optional[int] - (estimated) low-rank of the matrix D - mu: Optional[float] - initial stiffness parameter for the constraint on M, L and Q - tau: Optional[float] - penalizing parameter for the nuclear norm - lam: Optional[float] - penalizing parameter for the sparse matrix - list_periods: Optional[List[int]] - list of periods, linked to the Toeplitz matrices - list_etas: Optional[List[float]] - list of penalizing parameters for the corresponding period in list_periods - max_iterations: Optional[int] - stopping criteria, maximum number of iterations. By default, the value is set to 10_000 - tolerance: Optional[float] - stoppign critera, minimum difference between 2 consecutive iterations. By default, - the value is set to 1e-6 - norm: Optional[str] - error norm, can be "L1" or "L2". By default, the value is set to "L2" - verbose: Optional[bool] - verbosity level, if False the warnings are silenced """ def __init__( @@ -70,6 +44,33 @@ def __init__( norm: str = "L2", verbose: bool = True, ) -> None: + """ + Parameters + ---------- + random_state : int, optional + The seed of the pseudo random number generator to use, for reproductibility. + rank: Optional[int] + Upper bound of the rank to be estimated + mu: Optional[float] + initial stiffness parameter for the constraint M = L Q + tau: Optional[float] + penalizing parameter for the nuclear norm + lam: Optional[float] + penalizing parameter for the sparse matrix + list_periods: Optional[List[int]] + list of periods, linked to the Toeplitz matrices + list_etas: Optional[List[float]] + list of penalizing parameters for the corresponding period in list_periods + max_iterations: Optional[int] + stopping criteria, maximum number of iterations. By default, the value is set to 10_000 + tolerance: Optional[float] + stoppign critera, minimum difference between 2 consecutive iterations. By default, + the value is set to 1e-6 + norm: Optional[str] + error norm, can be "L1" or "L2". 
By default, the value is set to "L2" + verbose: Optional[bool] + verbosity level, if False the warnings are silenced + """ super().__init__(max_iterations=max_iterations, tolerance=tolerance, verbose=verbose) self.rng = sku.check_random_state(random_state) self.rank = rank @@ -101,7 +102,6 @@ def get_params_scale(self, D: NDArray) -> Dict[str, float]: Regularization parameter for the L1 norm. """ - D = utils.linear_interpolation(D) rank = rpca_utils.approx_rank(D) tau = 1.0 / np.sqrt(max(D.shape)) lam = tau @@ -136,7 +136,8 @@ def decompose_with_basis( self, D: NDArray, Omega: NDArray ) -> Tuple[NDArray, NDArray, NDArray, NDArray]: """ - Compute the noisy RPCA with L1 or L2 time penalisation + Compute the noisy RPCA with L1 or L2 time penalisation, and returns the decomposition of + the low-rank matrix. Parameters ---------- @@ -156,7 +157,7 @@ def decompose_with_basis( Q: NDArray Reduced basis of the low-rank matrix """ - + D = utils.linear_interpolation(D) self.params_scale = self.get_params_scale(D) if self.lam is not None: @@ -178,13 +179,6 @@ def decompose_with_basis( "The periods provided in argument in `list_periods` must smaller " f"than the number of rows in the matrix but {period} >= {n_rows}!" ) - # if (n_rows == 1) or (n_cols == 1): - # warnings.warn( - # f"RPCA algorithm may provide bad results. Function {function_str} increased from" - # f" {cost_start} to {cost_end} instead of decreasing!".format("%.2f") - # ) - - D = utils.linear_interpolation(D) M, A, L, Q = self.minimise_loss( D, @@ -219,7 +213,11 @@ def minimise_loss( norm: str = "L2", ) -> Tuple: """ - Compute the noisy RPCA with a L2 time penalisation + Compute the noisy RPCA with a L2 time penalisation. + + This function computes the noisy Robust Principal Component Analysis (RPCA) using a L2 time + penalisation. It iteratively minimizes a loss function to separate the low-rank and sparse + components from the input data matrix. Parameters ---------- @@ -227,40 +225,49 @@ def minimise_loss( Observations matrix of shape (m, n). Omega : np.ndarray Binary matrix indicating the observed entries of D, shape (m, n). - rank: Optional[int] - (estimated) low-rank of the matrix D - tau: Optional[float] - penalizing parameter for the nuclear norm - lam: Optional[float] - penalizing parameter for the sparse matrix - mu: Optional[float] - initial stiffness parameter for the constraint on M, L and Q - list_periods: Optional[List[int]] - list of periods, linked to the Toeplitz matrices - list_etas: Optional[List[float]] - list of penalizing parameters for the corresponding period in list_periods - max_iterations: Optional[int] - stopping criteria, maximum number of iterations. By default, the value is set to 10_000 - tolerance: Optional[float] - stoppign critera, minimum difference between 2 consecutive iterations. By default, - the value is set to 1e-6 - norm: Optional[str] - error norm, can be "L1" or "L2". By default, the value is set to "L2" + rank : int + Estimated low-rank of the matrix D. + tau : float + Penalizing parameter for the nuclear norm. + lam : float + Penalizing parameter for the sparse matrix. + mu : float, optional + Initial stiffness parameter for the constraint on M, L, and Q. Defaults + to 1e-2. + list_periods : List[int], optional + List of periods linked to the Toeplitz matrices. Defaults to []. + list_etas : List[float], optional + List of penalizing parameters for the corresponding periods in list_periods. Defaults + to []. + max_iterations : int, optional + Stopping criteria, maximum number of iterations. 
Defaults to 10000. + tolerance : float, optional + Stopping criteria, minimum difference between 2 consecutive iterations. + Defaults to 1e-6. + norm : str, optional + Error norm, can be "L1" or "L2". Defaults to "L2". Returns ------- - M : np.ndarray - Low-rank signal matrix of shape (m, n). - A : np.ndarray - Anomalies matrix of shape (m, n). - L : np.ndarray - Basis Unitary array of shape (m, rank). - Q : np.ndarray - Basis Unitary array of shape (rank, n). + Tuple + A tuple containing the following elements: + - M : np.ndarray + Low-rank signal matrix of shape (m, n). + - A : np.ndarray + Anomalies matrix of shape (m, n). + - L : np.ndarray + Basis unitary array of shape (m, rank). + - Q : np.ndarray + Basis unitary array of shape (rank, n). + + Raises + ------ + ValueError + If the periods provided in the argument in `list_periods` are not + smaller than the number of rows in the matrix. """ - print("minimise_loss") rho = 1.1 n_rows, n_cols = D.shape @@ -316,7 +323,6 @@ def minimise_loss( A_Omega = rpca_utils.soft_thresholding(D - M, lam) A_Omega_C = D - M A = np.where(Omega, A_Omega, A_Omega_C) - Q = scp.linalg.solve( a=tau * Ir + mu * (L.T @ L), b=L.T @ (mu * M + Y), @@ -360,6 +366,27 @@ def decompose_on_basis( Omega: NDArray, Q: NDArray, ) -> Tuple[NDArray, NDArray]: + """ + Decompose the matrix D with an observation matrix Omega using the noisy RPCA algorithm, + with a fixed reduced basis given by the matrix Q. This allows to impute new data without + resolving the optimization problem on the whole dataset. + + Parameters + ---------- + D : NDArray + _description_ + Omega : NDArray + _description_ + Q : NDArray + _description_ + + Returns + ------- + Tuple[NDArray, NDArray] + A tuple representing the decomposition of D with: + - M: low-rank matrix + - A: sparse matrix + """ D = utils.linear_interpolation(D) params_scale = self.get_params_scale(D) @@ -402,23 +429,24 @@ def decompose_on_basis( def _check_cost_function_minimized( self, - observations: NDArray, - low_rank: NDArray, - anomalies: NDArray, + D: NDArray, + M: NDArray, + A: NDArray, Omega: NDArray, tau: float, lam: float, ): - """Check that the functional minimized by the RPCA - is smaller at the end than at the beginning + """ + Check that the functional minimized by the RPCA is smaller at the end than at the + beginning. 
Parameters ---------- - observations : NDArray + D : NDArray observations matrix with first linear interpolation - low_rank : NDArray + M : NDArray low_rank matrix resulting from RPCA - anomalies : NDArray + A : NDArray sparse matrix resulting from RPCA Omega: NDArrau boolean matrix indicating the observed values @@ -428,9 +456,9 @@ def _check_cost_function_minimized( parameter penalizing the L1-norm of the anomaly/sparse part """ cost_start = self.cost_function( - observations, - observations, - np.full_like(observations, 0), + D, + D, + np.full_like(D, 0), Omega, tau, lam, @@ -439,9 +467,9 @@ def _check_cost_function_minimized( norm=self.norm, ) cost_end = self.cost_function( - observations, - low_rank, - anomalies, + D, + M, + A, Omega, tau, lam, @@ -449,12 +477,12 @@ def _check_cost_function_minimized( self.list_etas, norm=self.norm, ) - function_str = "1/2 $ ||D-M-A||_2 + tau ||D||_* + lam ||A||_1" + function_str = "1/2 ||D-M-A||_2 + tau ||D||_* + lam ||A||_1" if len(self.list_etas) > 0: for eta in self.list_etas: function_str += f"{eta} ||MH||_{self.norm}" - if self.verbose and (round(cost_start, 4) - round(cost_end, 4)) <= -1e-2: + if self.verbose and (cost_end > cost_start * (1 + 1e-6)): warnings.warn( f"RPCA algorithm may provide bad results. Function {function_str} increased from" f" {cost_start} to {cost_end} instead of decreasing!".format("%.2f") @@ -462,9 +490,9 @@ def _check_cost_function_minimized( @staticmethod def cost_function( - observations: NDArray, - low_rank: NDArray, - anomalies: NDArray, + D: NDArray, + M: NDArray, + A: NDArray, Omega: NDArray, tau: float, lam: float, @@ -473,15 +501,15 @@ def cost_function( norm: str = "L2", ): """ - Compute cost function for different RPCA algorithm + Estimated cost function for the noisy RPCA algorithm Parameters ---------- - observations : NDArray + D : NDArray Matrix of observations - low_rank : NDArray + M : NDArray Low-rank signal - anomalies : NDArray + A : NDArray Anomalies Omega : NDArray Mask for observations @@ -506,20 +534,17 @@ def cost_function( temporal_norm: float = 0 if len(list_etas) > 0: # matrices for temporal correlation - list_H = [ - rpca_utils.toeplitz_matrix(period, observations.shape[0]) - for period in list_periods - ] + list_H = [rpca_utils.toeplitz_matrix(period, D.shape[0]) for period in list_periods] if norm == "L1": for eta, H_matrix in zip(list_etas, list_H): - temporal_norm += eta * np.sum(np.abs(H_matrix @ low_rank)) + temporal_norm += eta * np.sum(np.abs(H_matrix @ M)) elif norm == "L2": for eta, H_matrix in zip(list_etas, list_H): - temporal_norm += eta * float(np.linalg.norm(H_matrix @ low_rank, "fro")) - anomalies_norm = np.sum(np.abs(anomalies * Omega)) + temporal_norm += eta * float(np.linalg.norm(H_matrix @ M, "fro")) + anomalies_norm = np.sum(np.abs(A * Omega)) cost = ( - 1 / 2 * ((Omega * (observations - low_rank - anomalies)) ** 2).sum() - + tau * np.linalg.norm(low_rank, "nuc") + 1 / 2 * ((Omega * (D - M - A)) ** 2).sum() + + tau * np.linalg.norm(M, "nuc") + lam * anomalies_norm + temporal_norm ) diff --git a/qolmat/imputations/rpca/rpca_pcp.py b/qolmat/imputations/rpca/rpca_pcp.py index 67dde3cb..f3b8e751 100644 --- a/qolmat/imputations/rpca/rpca_pcp.py +++ b/qolmat/imputations/rpca/rpca_pcp.py @@ -75,8 +75,7 @@ def get_params_scale(self, D: NDArray): Regularization parameter for the L1 norm. 
""" - D = utils.linear_interpolation(D) - mu = D.size / (4.0 * rpca_utils.l1_norm(D)) + mu = min(1e3, D.size / (4.0 * rpca_utils.l1_norm(D))) lam = 1 / np.sqrt(np.max(D.shape)) dict_params = {"mu": mu, "lam": lam} return dict_params @@ -100,13 +99,14 @@ def decompose(self, D: NDArray, Omega: NDArray) -> Tuple[NDArray, NDArray]: A: NDArray Anomalies """ + D = utils.linear_interpolation(D) + if np.all(D == 0): + return D, D params_scale = self.get_params_scale(D) mu = params_scale["mu"] if self.mu is None else self.mu lam = params_scale["lam"] if self.lam is None else self.lam - D = utils.linear_interpolation(D) - D_norm = np.linalg.norm(D, "fro") A = np.array(np.full_like(D, 0)) diff --git a/qolmat/imputations/rpca/rpca_utils.py b/qolmat/imputations/rpca/rpca_utils.py index 592d97ce..9e6c8945 100644 --- a/qolmat/imputations/rpca/rpca_utils.py +++ b/qolmat/imputations/rpca/rpca_utils.py @@ -29,6 +29,8 @@ def approx_rank( int: Approximated rank of M """ + if np.all(M == 0): + return 1 if threshold == 1: return min(M.shape) _, values_singular, _ = np.linalg.svd(M, full_matrices=False) @@ -80,7 +82,6 @@ def svd_thresholding(X: NDArray, threshold: float) -> NDArray: V is the array of the right singular vectors of X s is the array of the singular values as a diagonal array """ - U, s, Vh = np.linalg.svd(X, full_matrices=False) s = soft_thresholding(s, threshold) return U @ (np.diag(s) @ Vh) diff --git a/qolmat/utils/data.py b/qolmat/utils/data.py index e8678a10..2edd7c7f 100644 --- a/qolmat/utils/data.py +++ b/qolmat/utils/data.py @@ -100,7 +100,19 @@ def get_data( url_zenodo = "https://zenodo.org/record/" if name_data == "Beijing": df = read_csv_local("beijing") - df = df.set_index(["station", "date"]) + df["date"] = pd.to_datetime(df["date"]) + + # df["date"] = pd.to_datetime( + # { + # "year": df["year"], + # "month": df["month"], + # "day": df["day"], + # "hour": df["hour"], + # } + # ) + df = df.drop(columns=["year", "month", "day", "hour", "wd"]) + # df = df.set_index(["station", "date"]) + df = df.groupby(["station", "date"]).mean() return df if name_data == "Superconductor": df = read_csv_local("conductors") @@ -173,7 +185,8 @@ def get_data( return df elif name_data == "Monach_electricity_australia": urllink = os.path.join( - url_zenodo, "4659727/files/australian_electricity_demand_dataset.zip?download=1" + url_zenodo, + "4659727/files/australian_electricity_demand_dataset.zip?download=1", ) zipname = "australian_electricity_demand_dataset" list_loaded_data = download_data_from_zip(zipname, urllink, datapath=datapath) @@ -216,7 +229,8 @@ def preprocess_data_beijing(df: pd.DataFrame) -> pd.DataFrame: df["station"] = "Beijing" df.set_index(["station", "datetime"], inplace=True) df.drop( - columns=["year", "month", "day", "hour", "No", "cbwd", "Iws", "Is", "Ir"], inplace=True + columns=["year", "month", "day", "hour", "No", "cbwd", "Iws", "Is", "Ir"], + inplace=True, ) df.sort_index(inplace=True) df = df.groupby( diff --git a/qolmat/utils/exceptions.py b/qolmat/utils/exceptions.py index 5494ede6..eb00da95 100644 --- a/qolmat/utils/exceptions.py +++ b/qolmat/utils/exceptions.py @@ -56,3 +56,12 @@ def __init__(self): class SingleSample(Exception): def __init__(self): super().__init__("""This imputer cannot be fitted on a single sample!""") + + +class IllConditioned(Exception): + def __init__(self, min_sv: float, min_std: float): + super().__init__( + f"The covariance matrix is ill-conditioned, indicating high-colinearity: the smallest " + f"singular value of the data matrix is smaller than the 
threshold min_std ({min_sv} < " + f"{min_std}). Consider removing columns of decreasing the threshold." + ) diff --git a/qolmat/utils/plot.py b/qolmat/utils/plot.py index d37d3f46..c6700e13 100644 --- a/qolmat/utils/plot.py +++ b/qolmat/utils/plot.py @@ -156,8 +156,9 @@ def plot_images( def make_ellipses( - x: NDArray, - y: NDArray, + mean_x: float, + mean_y: float, + cov: NDArray, ax: mpl.axes.Axes, n_std: float = 2, color: Union[str, Any, Tuple[float, float, float]] = "None", @@ -167,9 +168,12 @@ def make_ellipses( Parameters ---------- - x, y : array-like, shape (n, ) - Input data. - + mean_x : float + Abscisse of the ellipse center + mean_y : float + Ordinate of the ellipse center + cov : NDArray + Covariance matrix defining the ellipse ax : matplotlib.axes.Axes The axes object to draw the ellipse into. @@ -183,18 +187,13 @@ def make_ellipses( ------- matplotlib.patches.Ellipse """ - if x.size != y.size: - raise ValueError("x and y must be the same size") - cov = np.cov(x, y) pearson = cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1]) ell_radius_x = np.sqrt(1 + pearson) * 2.5 ell_radius_y = np.sqrt(1 - pearson) * 2.5 ell = mpl.patches.Ellipse((0, 0), width=ell_radius_x, height=ell_radius_y, facecolor=color) scale_x = np.sqrt(cov[0, 0]) * n_std - mean_x = np.mean(x) scale_y = np.sqrt(cov[1, 1]) * n_std - mean_y = np.mean(y) transf = ( mpl.transforms.Affine2D().rotate_deg(45).scale(scale_x, scale_y).translate(mean_x, mean_y) ) @@ -205,6 +204,43 @@ def make_ellipses( ax.set_aspect("equal", "datalim") +def make_ellipses_from_data( + x: NDArray, + y: NDArray, + ax: mpl.axes.Axes, + n_std: float = 2, + color: Union[str, Any, Tuple[float, float, float]] = "None", +): + """ + Create a plot of the covariance confidence ellipse of *x* and *y*. + + Parameters + ---------- + x, y : array-like, shape (n, ) + Input data. + + ax : matplotlib.axes.Axes + The axes object to draw the ellipse into. + + n_std : float + The number of standard deviations to determine the ellipse's radiuses. 
+ + color : Optional[str] + facecolor + + Returns + ------- + matplotlib.patches.Ellipse + """ + if x.size != y.size: + raise ValueError("x and y must be the same size") + + cov = np.cov(x, y) + mean_x = np.mean(x) + mean_y = np.mean(y) + make_ellipses(mean_x, mean_y, cov, ax, n_std, color) + + def compare_covariances( df_1: pd.DataFrame, df_2: pd.DataFrame, @@ -235,9 +271,17 @@ def compare_covariances( if color is None: color = tab10(0) ax.scatter(df2[col_x], df2[col_y], marker=".", color=color, s=2, alpha=0.7, label="imputed") - ax.scatter(df1[col_x], df1[col_y], marker=".", color="black", s=2, alpha=0.7, label="original") - make_ellipses(df1[col_x], df1[col_y], ax, color="black") - make_ellipses(df2[col_x], df2[col_y], ax, color=color) + ax.scatter( + df1[col_x], + df1[col_y], + marker=".", + color="black", + s=2, + alpha=0.7, + label="original", + ) + make_ellipses_from_data(df1[col_x], df1[col_y], ax, color="black") + make_ellipses_from_data(df2[col_x], df2[col_y], ax, color=color) ax.set_xlabel(col_x) ax.set_ylabel(col_y) @@ -297,7 +341,7 @@ def multibar( color=color_col, ) plt.xticks(x, df.index) - ax.bar_label(rect, padding=3, fmt=f"%.{decimals}f") + ax.bar_label(rect, padding=3, fmt=f"%.{decimals}g") plt.legend(loc=(1, 0)) diff --git a/qolmat/utils/utils.py b/qolmat/utils/utils.py index 7886a161..f1785c75 100644 --- a/qolmat/utils/utils.py +++ b/qolmat/utils/utils.py @@ -215,3 +215,18 @@ def create_lag_matrices(X: NDArray, p: int) -> Tuple[NDArray, NDArray]: Z = np.concatenate(list_X_lag, axis=1) Y = X[-n_rows_new:, :] return Z, Y + + +def nancov(X: NDArray) -> NDArray: + _, n_cols = X.shape + cov = np.nan * np.zeros((n_cols, n_cols)) + mask = np.isnan(X) + for i in range(n_cols): + Di = X[:, i] - np.nanmean(X[:, i]) + for j in range(n_cols): + select = (~mask[:, i]) & (~mask[:, j]) + Di = X[select, i] - np.mean(X[select, i]) + Dj = X[select, j] - np.mean(X[select, j]) + cov[i, j] = np.nanmean(Di * Dj) + cov = impute_nans(cov, method="zeros") + return cov diff --git a/tests/imputations/test_em_sampler.py b/tests/imputations/test_em_sampler.py index d3ab1cf0..dfc01d5a 100644 --- a/tests/imputations/test_em_sampler.py +++ b/tests/imputations/test_em_sampler.py @@ -3,20 +3,21 @@ import pytest from numpy.typing import NDArray from scipy import linalg +import scipy from sklearn.datasets import make_spd_matrix +from qolmat.utils import utils from qolmat.imputations import em_sampler +from qolmat.utils.exceptions import IllConditioned np.random.seed(42) A: NDArray = np.array([[3, 1, 0], [1, 1, 0], [0, 0, 1]], dtype=float) A_inverse: NDArray = np.array([[0.5, -0.5, 0], [-0.5, 1.5, 0], [0, 0, 1]], dtype=float) X_missing = np.array( - [[1, np.nan, 1], [1, np.nan, 3], [1, 4, np.nan], [1, 2, 1], [1, 1, np.nan]], dtype=float -) -X_first_guess: NDArray = np.array( - [[1, 4, 1], [1, 4, 3], [1, 4, 4], [1, 2, 1], [1, 1, 4]], dtype=float + [[1, np.nan, 1], [2, np.nan, 3], [1, 4, np.nan], [-1, 2, 1], [1, 1, np.nan]], + dtype=float, ) mask: NDArray = np.isnan(X_missing) @@ -32,7 +33,9 @@ def generate_multinormal_predefined_mean_cov(d=3, n=500): mask = np.array(np.full_like(X, False), dtype=bool) for j in range(X.shape[1]): ind = rng.choice( - np.arange(X.shape[0]), size=np.int64(np.ceil(X.shape[0] * 0.1)), replace=False + np.arange(X.shape[0]), + size=np.int64(np.ceil(X.shape[0] * 0.1)), + replace=False, ) mask[ind, j] = True X_missing = X.copy() @@ -69,7 +72,9 @@ def generate_varp_process(d=3, n=10000, p=1): mask = np.array(np.full_like(X, False), dtype=bool) for j in range(X.shape[1]): ind = 
diff --git a/tests/imputations/test_em_sampler.py b/tests/imputations/test_em_sampler.py
index d3ab1cf0..dfc01d5a 100644
--- a/tests/imputations/test_em_sampler.py
+++ b/tests/imputations/test_em_sampler.py
@@ -3,20 +3,21 @@
 import pytest
 from numpy.typing import NDArray
 from scipy import linalg
+import scipy
 from sklearn.datasets import make_spd_matrix
 
+from qolmat.utils import utils
 from qolmat.imputations import em_sampler
+from qolmat.utils.exceptions import IllConditioned
 
 np.random.seed(42)
 
 A: NDArray = np.array([[3, 1, 0], [1, 1, 0], [0, 0, 1]], dtype=float)
 A_inverse: NDArray = np.array([[0.5, -0.5, 0], [-0.5, 1.5, 0], [0, 0, 1]], dtype=float)
 X_missing = np.array(
-    [[1, np.nan, 1], [1, np.nan, 3], [1, 4, np.nan], [1, 2, 1], [1, 1, np.nan]], dtype=float
-)
-X_first_guess: NDArray = np.array(
-    [[1, 4, 1], [1, 4, 3], [1, 4, 4], [1, 2, 1], [1, 1, 4]], dtype=float
+    [[1, np.nan, 1], [2, np.nan, 3], [1, 4, np.nan], [-1, 2, 1], [1, 1, np.nan]],
+    dtype=float,
 )
 mask: NDArray = np.isnan(X_missing)
 
@@ -32,7 +33,9 @@ def generate_multinormal_predefined_mean_cov(d=3, n=500):
     mask = np.array(np.full_like(X, False), dtype=bool)
     for j in range(X.shape[1]):
         ind = rng.choice(
-            np.arange(X.shape[0]), size=np.int64(np.ceil(X.shape[0] * 0.1)), replace=False
+            np.arange(X.shape[0]),
+            size=np.int64(np.ceil(X.shape[0] * 0.1)),
+            replace=False,
         )
         mask[ind, j] = True
     X_missing = X.copy()
@@ -69,7 +72,9 @@ def generate_varp_process(d=3, n=10000, p=1):
     mask = np.array(np.full_like(X, False), dtype=bool)
     for j in range(X.shape[1]):
         ind = rng.choice(
-            np.arange(X.shape[0]), size=np.int64(np.ceil(X.shape[0] * 0.1)), replace=False
+            np.arange(X.shape[0]),
+            size=np.int64(np.ceil(X.shape[0] * 0.1)),
+            replace=False,
         )
         mask[ind, j] = True
     X_missing = X.copy()
@@ -78,21 +83,22 @@
 
 
 @pytest.mark.parametrize(
-    "A, X_first_guess, mask",
-    [(A, X_first_guess, mask)],
+    "A, mask",
+    [(A, mask)],
 )
 def test_gradient_conjugue(
     A: NDArray,
-    X_first_guess: NDArray,
     mask: NDArray,
 ) -> None:
     """Test the conjugate gradient algorithm."""
+    X_first_guess = utils.impute_nans(X_missing)
     X_result = em_sampler._conjugate_gradient(A, X_first_guess, mask)
-    X_expected = np.array([[1, -1, 1], [1, -1, 3], [1, 4, 0], [1, 2, 1], [1, 1, 0]], dtype=float)
+    X_expected = np.array([[1, -1, 1], [2, -2, 3], [1, 4, 0], [-1, 2, 1], [1, 1, 0]], dtype=float)
 
-    np.testing.assert_allclose(X_result, X_expected, atol=1e-5)
     assert np.sum(X_result * (X_result @ A)) <= np.sum(X_first_guess * (X_first_guess @ A))
-    assert np.allclose(X_first_guess[~mask], X_result[~mask])
+    assert np.allclose(X_missing[~mask], X_result[~mask])
+    assert ((X_result @ A)[mask] == 0).all()
+    np.testing.assert_allclose(X_result, X_expected, atol=1e-5)
 
 
 def test_get_lag_p():
@@ -136,9 +142,9 @@ def test_fit_calls(mocker, X_missing: NDArray) -> None:
     em = em_sampler.MultiNormalEM(max_iter_em=max_iter_em)
     em.fit(X_missing)
     assert mock_sample_ou.call_count == max_iter_em
-    assert mock_maximize_likelihood.call_count == 0
+    assert mock_maximize_likelihood.call_count == 1
     assert mock_check_convergence.call_count == max_iter_em
-    assert mock_fit_parameters.call_count == 1
+    assert mock_fit_parameters.call_count == 0
     assert mock_combine_parameters.call_count == max_iter_em
     assert mock_update_criteria_stop.call_count == max_iter_em
 
@@ -191,7 +197,48 @@ def test_em_sampler_check_convergence_false(
     em.dict_criteria_stop["means"] = means
     em.dict_criteria_stop["covs"] = covs
     em.dict_criteria_stop["logliks"] = logliks
-    assert em._check_convergence() == False
+    assert em._check_convergence() == True
+
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        em_sampler.MultiNormalEM(method="sample", n_iter_ou=512, dt=1e-2),
+        em_sampler.VARpEM(method="sample", n_iter_ou=512, dt=1e-2, p=0),
+    ],
+)
+def test_sample_ou_2d(model):
+    # model = em_sampler.MultiNormalEM(method="sample", n_iter_ou=512, dt=1e-2)
+    means = np.array([5, -2])
+    cov = np.array([[1, -0.5], [-0.5, 2]])
+    if isinstance(model, em_sampler.VARpEM):
+        model.set_parameters(means.reshape(1, -1), cov)
+    else:
+        model.set_parameters(means, cov)
+    n_samples = 10000
+    x1 = 4
+    D = x1 * np.ones((n_samples, 2))
+    D[:, 0] = np.nan
+    values = model.transform(D)[:, 0]
+    mean_theo = means[0] + cov[0, 1] / cov[1, 1] * (x1 - means[1])
+    var_theo = cov[0, 0] - cov[0, 1] ** 2 / cov[1, 1]
+    mean_est = np.mean(values)
+    var_est = np.var(values)
+    alpha = 0.01
+    q_alpha = scipy.stats.norm.ppf(1 - alpha / 2)
+
+    print(mean_est, "vs", mean_theo)
+    assert abs(mean_est - mean_theo) < np.sqrt(var_theo / n_samples) * q_alpha
+
+    ratio_inf = scipy.stats.chi2.ppf(alpha / 2, n_samples) / (n_samples - 1)
+    ratio_sup = scipy.stats.chi2.ppf(1 - alpha / 2, n_samples) / (n_samples - 1)
+
+    ratio = var_est / var_theo
+
+    print(var_est, "vs", var_theo)
+    print(ratio_inf, "<", ratio, "<", ratio_sup)
+    assert ratio_inf <= ratio
+    assert ratio <= ratio_sup
 
 
 @pytest.mark.parametrize(
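The theoretical moments `mean_theo` and `var_theo` in `test_sample_ou_2d` are the standard conditioning formulas for a bivariate Gaussian: given X1 = x1, the variable X0 is normal with mean mu0 + (Sigma01 / Sigma11) * (x1 - mu1) and variance Sigma00 - Sigma01**2 / Sigma11. As an independent sanity check, the following standalone sketch (plain NumPy, illustrative values matching the test) recovers the same quantities from the joint distribution by linear regression:

```python
import numpy as np

means = np.array([5.0, -2.0])
cov = np.array([[1.0, -0.5], [-0.5, 2.0]])

rng = np.random.default_rng(0)
x = rng.multivariate_normal(means, cov, size=200_000)

# For a Gaussian vector, E[X0 | X1] is affine in X1 with slope cov01 / cov11,
# and the residual variance equals cov00 - cov01**2 / cov11.
slope, intercept = np.polyfit(x[:, 1], x[:, 0], 1)
residual_var = np.var(x[:, 0] - (slope * x[:, 1] + intercept))

print(slope, cov[0, 1] / cov[1, 1])                          # both close to -0.25
print(residual_var, cov[0, 0] - cov[0, 1] ** 2 / cov[1, 1])  # both close to 0.875
```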
@@ -231,12 +278,23 @@ def test_varem_sampler_check_convergence_false(
     em.dict_criteria_stop["B"] = list_B
     em.dict_criteria_stop["S"] = list_S
     em.dict_criteria_stop["logliks"] = logliks
-    assert em._check_convergence() == False
+    assert em._check_convergence() == True
+
+
+def test_illconditioned_multinormalem() -> None:
+    """Test that data with colinearity triggers a warning."""
+    X = np.array([[1, np.nan, 8, 1], [3, 1, 4, 2], [2, 3, np.nan, 1]], dtype=float)
+    model = em_sampler.MultiNormalEM()
+    with pytest.warns(UserWarning):
+        _ = model.fit_transform(X)
+    # except IllConditioned:
+    #     return
+    # assert False
 
 
 def test_no_more_nan_multinormalem() -> None:
     """Test there are no more missing values after the MultiNormalEM algorithm."""
-    X = np.array([[1, np.nan, 8, 1], [3, 1, 4, 2], [2, 3, np.nan, 1]], dtype=float)
+    X = np.array([[1, np.nan], [3, 1], [np.nan, 3]], dtype=float)
     model = em_sampler.MultiNormalEM()
     X_imp = model.fit_transform(X)
     assert np.sum(np.isnan(X)) > 0
@@ -297,7 +355,7 @@ def test_multinormal_em_minimize_llik():
 @pytest.mark.parametrize("method", ["sample", "mle"])
 def test_multinormal_em_fit_transform(method: Literal["mle", "sample"]):
     imputer = em_sampler.MultiNormalEM(method=method, random_state=11)
-    X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]])
+    X = X_missing.copy()
     result = imputer.fit_transform(X)
     assert result.shape == X.shape
     np.testing.assert_allclose(result[~np.isnan(X)], X[~np.isnan(X)])
@@ -331,25 +389,22 @@ def test_parameters_after_imputation_varpem(p: int):
 
 
 def test_varpem_fit_transform():
-    imputer = em_sampler.VARpEM(method="sample", random_state=11)
+    imputer = em_sampler.VARpEM(method="mle", random_state=11)
     X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]])
     result = imputer.fit_transform(X)
-    expected = np.array(
-        [
-            [1.0, 1.0, 1.0, 1.0],
-            [1.0, 1.5, 3.0, 2.0],
-            [1.0, 2.0, 2.0, 1.0],
-            [2.0, 2.0, 2.0, 2.0],
-        ]
-    )
-    np.testing.assert_allclose(result, expected, atol=1e-12)
+    assert result.shape == X.shape
+    np.testing.assert_allclose(result[~np.isnan(X)], X[~np.isnan(X)])
+    assert not np.any(np.isnan(result))
 
 
 @pytest.mark.parametrize(
-    "X, em, p",
-    [(X_first_guess, em_sampler.MultiNormalEM(), 0), (X_first_guess, em_sampler.VARpEM(p=2), 2)],
+    "em, p",
+    [
+        (em_sampler.MultiNormalEM(), 0),
+        (em_sampler.VARpEM(p=2), 2),
+    ],
 )
-def test_gradient_X_loglik(X: NDArray, em: em_sampler.EM, p: int):
+def test_gradient_X_loglik(em: em_sampler.EM, p: int):
     d = 3
     X, _, _, _ = generate_varp_process(d=d, n=10, p=p)
     em.fit_parameters(X)
diff --git a/tests/imputations/test_imputers.py b/tests/imputations/test_imputers.py
index cab26a9c..20f6f39b 100644
--- a/tests/imputations/test_imputers.py
+++ b/tests/imputations/test_imputers.py
@@ -263,41 +263,10 @@ def test_ImputerRegressor_fit_transform(df: pd.DataFrame) -> None:
 @pytest.mark.parametrize("df", [df_timeseries])
 def test_ImputerRpcaNoisy_fit_transform(df: pd.DataFrame) -> None:
     imputer = imputers.ImputerRpcaNoisy(columnwise=False, max_iterations=100, tau=1, lam=0.3)
-    imputer = imputer.fit(df)
-    result = imputer.transform(df)
-    expected = pd.DataFrame(
-        {
-            "col1": [i for i in range(20)],
-            "col2": [0, 1, 2, 2, 2] + [i for i in range(5, 20)],
-        }
-    )
-    result = np.around(result)
-    np.testing.assert_allclose(result, expected, atol=1e-2)
-
-    result = imputer.transform(df.iloc[:10])
-    expected = pd.DataFrame(
-        {
-            "col1": [i for i in range(10)],
-            "col2": [0, 1, 2, 2, 2] + [i for i in range(5, 10)],
-        }
-    )
-    result = np.around(result)
-    np.testing.assert_allclose(result, expected, atol=1e-2)
-
-
-# @pytest.mark.parametrize("df", [df_incomplete])
-# def test_ImputerSoftImpute_fit_transform(df: pd.DataFrame) -> None:
-#     imputer = imputers.ImputerSoftImpute(
-#         columnwise=False, max_iterations=100, tau=0.3, random_state=4
-#     )
-#     result = imputer.fit_transform(df)
-#     expected = pd.DataFrame(
-#         {
-#             "col1": [0, 1.327, 2, 3, 0.137],
-#             "col2": [-1, 0.099, 0.5, 0.122, 1.5],
-#         }
-#     )
-#     np.testing.assert_allclose(result, expected, atol=1e-2)
+    df_omega = df.notna()
+    df_result = imputer.fit_transform(df)
+    np.testing.assert_allclose(df_result[df_omega], df[df_omega])
+    assert df_result.notna().all().all()
 
 
 index_grouped = pd.MultiIndex.from_product([["a", "b"], range(4)], names=["group", "date"])
@@ -322,7 +291,7 @@ def test_ImputerRpcaNoisy_fit_transform(df: pd.DataFrame) -> None:
     imputers.ImputerRpcaPcp(groups=("group",)),
     imputers.ImputerRpcaNoisy(groups=("group",)),
     imputers.ImputerSoftImpute(groups=("group",)),
-    imputers.ImputerEM(groups=("group",)),
+    imputers.ImputerEM(groups=("group",), method="mle"),
 ]
 
 
@@ -347,9 +316,9 @@ def test_models_fit_transform_grouped(imputer):
     imputers.ImputerResiduals(period=2),
     imputers.KNNImputer(),
     imputers.ImputerMICE(),
-    imputers.ImputerRegressor(),
-    imputers.ImputerRpcaNoisy(tau=0, lam=0),
-    imputers.ImputerRpcaPcp(lam=0),
+    imputers.ImputerRegressor(estimator=LinearRegression()),
+    imputers.ImputerRpcaNoisy(tau=1, lam=1),
+    imputers.ImputerRpcaPcp(lam=1),
     imputers.ImputerSoftImpute(),
     imputers.ImputerEM(),
 ]
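The rewritten tests above replace hard-coded expected values with a common imputation contract: the result has the input shape, observed entries are returned unchanged, and no NaN remains. A reusable sketch of that check (the helper name `assert_imputation_contract` is ours, not part of qolmat):

```python
import numpy as np
import pandas as pd


def assert_imputation_contract(df: pd.DataFrame, df_imputed: pd.DataFrame) -> None:
    """Shared checks mirroring the rewritten fit_transform tests."""
    # Same shape as the input
    assert df_imputed.shape == df.shape
    # Observed values must be returned unchanged
    mask_observed = df.notna()
    np.testing.assert_allclose(df_imputed[mask_observed], df[mask_observed])
    # No missing value may remain after imputation
    assert df_imputed.notna().all().all()
```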
diff --git a/tests/utils/test_data.py b/tests/utils/test_data.py
index 41f69fac..0e08a3a7 100644
--- a/tests/utils/test_data.py
+++ b/tests/utils/test_data.py
@@ -36,7 +36,9 @@
     names=["station", "datetime"],
 )
 df_preprocess_beijing = pd.DataFrame(
-    [[1, 2], [3, np.nan], [np.nan, 6]], columns=["a", "b"], index=index_preprocess_beijing
+    [[1, 2], [3, np.nan], [np.nan, 6]],
+    columns=["a", "b"],
+    index=index_preprocess_beijing,
 )
 
 columns = ["mean_atomic_mass", "wtd_mean_atomic_mass"]
@@ -113,7 +115,9 @@
     names=["station", "datetime"],
 )
 df_preprocess_offline = pd.DataFrame(
-    [[1, 2], [3, np.nan], [np.nan, 6]], columns=["a", "b"], index=index_preprocess_offline
+    [[1, 2], [3, np.nan], [np.nan, 6]],
+    columns=["a", "b"],
+    index=index_preprocess_offline,
 )
 
 
@@ -167,7 +171,7 @@ def test_utils_data_get_data(name_data: str, df: pd.DataFrame, mocker: MockerFix
     if name_data == "Beijing":
         assert mock_download.call_count == 0
         assert mock_read.call_count == 1
-        pd.testing.assert_frame_equal(df_result, df.set_index(["station", "date"]))
+        assert df_result.index.names == ["station", "date"]
     elif name_data == "Superconductor":
         assert mock_download.call_count == 0
         assert mock_read.call_count == 1
@@ -213,8 +217,6 @@ def test_utils_data_get_data_corrupted(
 ) -> None:
     mock_get = mocker.patch("qolmat.utils.data.get_data", return_value=df)
     df_out = data.get_data_corrupted(name_data)
-    print(df_out)
-    print(df)
     assert mock_get.call_count == 1
     assert df_out.shape == df.shape
     pd.testing.assert_index_equal(df_out.index, df.index)
diff --git a/tests/utils/test_plot.py b/tests/utils/test_plot.py
index cf891d01..5c45e72e 100644
--- a/tests/utils/test_plot.py
+++ b/tests/utils/test_plot.py
@@ -72,10 +72,10 @@ def test__utils_plot_plot_images(
 
 
 @pytest.mark.parametrize("X", [X])
-def test_utils_plot_make_ellipses(X: np.ndarray, mocker: MockerFixture):
+def test_utils_plot_make_ellipses_from_data(X: np.ndarray, mocker: MockerFixture):
     mocker.patch("matplotlib.pyplot.show")
     ax = plt.gca()
-    plot.make_ellipses(X[1], X[2], ax, color="blue")
+    plot.make_ellipses_from_data(X[1], X[2], ax, color="blue")
     assert len(plt.gcf().get_axes()) > 0
     plt.close("all")