From d5caf2388c9a3dc72525f36667ca2955c0d0efbd Mon Sep 17 00:00:00 2001 From: Julien Roussel <3178729-JulienRoussel77@users.noreply.gitlab.com> Date: Sun, 14 Apr 2024 19:16:41 +0200 Subject: [PATCH 1/3] metrics updated --- HISTORY.rst | 1 + docs/imputers.rst | 30 +- examples/benchmark.md | 8 +- examples/tutorials/plot_tuto_categorical.py | 2 +- qolmat/benchmark/metrics.py | 291 ++++++++++---------- qolmat/imputations/em_sampler.py | 18 +- qolmat/imputations/preprocessing.py | 42 +-- qolmat/utils/algebra.py | 83 ++++++ qolmat/utils/utils.py | 21 +- tests/benchmark/test_metrics.py | 99 ++++--- tests/imputations/test_preprocessing.py | 8 +- tests/utils/test_algebra.py | 31 +++ 12 files changed, 386 insertions(+), 248 deletions(-) create mode 100644 qolmat/utils/algebra.py create mode 100644 tests/utils/test_algebra.py diff --git a/HISTORY.rst b/HISTORY.rst index 46428cff..52e62f1d 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -10,6 +10,7 @@ History * Tutorial plot_tuto_categorical showcasing mixed type imputation * Titanic dataset added * accuracy metric implemented +* metrics.py rationalized, and split with algebra.py 0.1.3 (2024-03-07) ------------------ diff --git a/docs/imputers.rst b/docs/imputers.rst index a8e4552c..633024e6 100644 --- a/docs/imputers.rst +++ b/docs/imputers.rst @@ -3,24 +3,28 @@ Imputers All imputers can be found in the ``qolmat.imputations`` folder. -1. Simple (mean/median/shuffle) -------------------------------- -Imputes the missing values using the mean/median along each column or with a random value in each column. See the :class:`~qolmat.imputations.imputers.ImputerSimple` and :class:`~qolmat.imputations.imputers.ImputerShuffle` classes. +1. Simple (mean/median/mode) +---------------------------- +Imputes the missing values using a basic statistic: the mode (most frequent value) for the categorical columns, and the mean, median or mode (depending on the user parameter) for the numerical columns. 
See :class:`~qolmat.imputations.imputers.ImputerSimple`. -2. LOCF +2. Shuffle +---------- +Imputes the missing values using a random value sampled in the same column. See :class:`~qolmat.imputations.imputers.ImputerShuffle`. + +3. LOCF ------- -Imputes the missing values using the last observation carried forward. See the :class:`~qolmat.imputations.imputers.ImputerLOCF` class. +Imputes the missing values using the last observation carried forward. See :class:`~qolmat.imputations.imputers.ImputerLOCF`. -3. Time interpolation and TSA decomposition +4. Time interpolation and TSA decomposition ------------------------------------------- -Imputes missing using some interpolation strategies supported by `pd.Series.interpolate `_. It is done column by column. See the :class:`~qolmat.imputations.imputers.ImputerInterpolation` class. When data are temporal with clear seasonal decomposition, we can interpolate on the residuals instead of directly interpolate the raw data. Series are de-seasonalised based on `statsmodels.tsa.seasonal.seasonal_decompose `_, residuals are imputed via linear interpolation, then residuals are re-seasonalised. It is also done column by column. See the :class:`~qolmat.imputations.imputers.ImputerResiduals` class. +Imputes missing values using some interpolation strategies supported by `pd.Series.interpolate `_. It is done column by column. See the :class:`~qolmat.imputations.imputers.ImputerInterpolation` class. When data are temporal with clear seasonal decomposition, we can interpolate on the residuals instead of directly interpolating the raw data. Series are de-seasonalised based on `statsmodels.tsa.seasonal.seasonal_decompose `_, residuals are imputed via linear interpolation, then residuals are re-seasonalised. It is also done column by column. See :class:`~qolmat.imputations.imputers.ImputerResiduals`. -4. MICE +5. MICE ------- Multiple Imputation by Chained Equation: multiple imputations based on ICE. It uses `IterativeImputer `_. 
See the :class:`~qolmat.imputations.imputers.ImputerMICE` class. -5. RPCA +6. RPCA ------- Robust Principal Component Analysis (RPCA) is a modification of the statistical procedure of PCA which allows to work with a data matrix :math:`\mathbf{D} \in \mathbb{R}^{n \times d}` containing missing values and grossly corrupted observations. We consider here the imputation task alone, but these methods can also tackle anomaly correction. @@ -46,7 +50,7 @@ The class :class:`RpcaNoisy` implements an recommanded improved version, which r with :math:`\mathbf{E} = \mathbf{D} - \mathbf{M} - \mathbf{A}`. See the :class:`~qolmat.imputations.imputers.ImputerRpcaNoisy` class for implementation details. -6. SoftImpute +7. SoftImpute ------------- SoftImpute is an iterative method for matrix completion that uses nuclear-norm regularization [11]. It is a faster alternative to RPCA, although it is much less robust due to the quadratic penalization. Given a matrix :math:`\mathbf{D} \in \mathbb{R}^{n \times d}` with observed entries indexed by the set :math:`\Omega`, this algorithm solves the following problem: @@ -56,11 +60,11 @@ SoftImpute is an iterative method for matrix completion that uses nuclear-norm r The imputed values are then given by the matrix :math:`M=LQ` on the unobserved data. See the :class:`~qolmat.imputations.imputers.ImputerSoftImpute` class for implementation details. -7. KNN +8. KNN ------ K-nearest neighbors, based on `KNNImputer `_. See the :class:`~qolmat.imputations.imputers.ImputerKNN` class. -8. EM sampler +9. EM sampler ------------- Imputes missing values via EM algorithm [5], and more precisely via MCEM algorithm [6]. See the :class:`~qolmat.imputations.imputers.ImputerEM` class. Suppose the data :math:`\mathbf{X}` has a density :math:`p_\theta` parametrized by some parameter :math:`\theta`. The EM algorithm allows to draw samples from this distribution by alternating between the expectation and maximization steps. 
@@ -104,7 +108,7 @@ Two parametric distributions are implemented: * :class:`~qolmat.imputations.em_sampler.VARpEM`: [7]: :math:`\mathbf{X} \in \mathbb{R}^{n \times d} \sim VAR_p(\nu, B_1, ..., B_p)` is generated by a VAR(p) process such that :math:`X_t = \nu + B_1 X_{t-1} + ... + B_p X_{t-p} + u_t` where :math:`\nu \in \mathbb{R}^d` is a vector of intercept terms, the :math:`B_i \in \mathbb{R}^{d \times d}` are the lags coefficient matrices and :math:`u_t` is white noise nonsingular covariance matrix :math:`\Sigma_u \mathbb{R}^{d \times d}`, so that :math:`\theta = (\nu, B_1, ..., B_p, \Sigma_u)`. -9. TabDDPM +10. TabDDPM ----------- :class:`~qolmat.imputations.diffusions.ddpms.TabDDPM` is a deep learning imputer based on Denoising Diffusion Probabilistic Models (DDPMs) [8] for handling multivariate tabular data. Our implementation mainly follows the works of [8, 9]. Diffusion models focus on modeling the process of data transitions from noisy and incomplete observations to the underlying true data. 
They include two main processes: diff --git a/examples/benchmark.md b/examples/benchmark.md index 45b201b5..e2fe8c1d 100644 --- a/examples/benchmark.md +++ b/examples/benchmark.md @@ -244,8 +244,8 @@ dfs_imputed["VAR_max"].groupby("station").min() ``` ```python tags=[] -# station = df_plot.index.get_level_values("station")[0] -station = "Huairou" +station = df_plot.index.get_level_values("station")[0] +# station = "Huairou" df_station = df_plot.loc[station] dfs_imputed_station = {name: df_plot.loc[station] for name, df_plot in dfs_imputed.items()} ``` @@ -362,7 +362,7 @@ comparison = comparator.Comparator( ) ``` -```python jupyter={"outputs_hidden": true} tags=[] +```python tags=[] generator_holes = missing_patterns.EmpiricalHoleGenerator(n_splits=3, groups=('station',), subset=cols_to_impute, ratio_masked=ratio_masked) comparison = comparator.Comparator( @@ -393,7 +393,7 @@ plt.show() df_plot = df_data[cols_to_impute] ``` -```python jupyter={"outputs_hidden": true} tags=[] +```python tags=[] dfs_imputed = {name: imp.fit_transform(df_plot) for name, imp in dict_imputers.items()} ``` diff --git a/examples/tutorials/plot_tuto_categorical.py b/examples/tutorials/plot_tuto_categorical.py index 0ab886b8..6940d50b 100644 --- a/examples/tutorials/plot_tuto_categorical.py +++ b/examples/tutorials/plot_tuto_categorical.py @@ -57,7 +57,7 @@ # - manage categorical features though one hot encoding # - manage missing features (native to the HistGradientBoosting) -pipestimator = preprocessing.make_robust_MixteHGB(allow_new=False) +pipestimator = preprocessing.make_robust_MixteHGB(avoid_new=True) imputer_hgb = ImputerRegressor(estimator=pipestimator, handler_nan="none") imputer_wrap_hgb = preprocessing.WrapperTransformer(imputer_hgb, bt) diff --git a/qolmat/benchmark/metrics.py b/qolmat/benchmark/metrics.py index dd72e612..3a2699b7 100644 --- a/qolmat/benchmark/metrics.py +++ b/qolmat/benchmark/metrics.py @@ -7,7 +7,9 @@ from sklearn import metrics as skm import dcor +from 
qolmat.utils import algebra, utils from qolmat.utils.exceptions import NotEnoughSamples +from numpy.linalg import LinAlgError EPS = np.finfo(float).eps @@ -48,12 +50,18 @@ def columnwise_metric( pd.Series Series of scores for all columns """ + try: + pd.testing.assert_index_equal(df1.columns, df2.columns) + except AssertionError: + raise ValueError( + f"Input dataframes do not have the same columns! ({df1.columns} != {df2.columns})" + ) if type_cols == "all": cols = df1.columns elif type_cols == "numerical": - cols = df1.select_dtypes(include=["number"]).columns + cols = _get_numerical_features(df1) elif type_cols == "categorical": - cols = df1.select_dtypes(exclude=["number"]).columns + cols = _get_categorical_features(df1) else: raise ValueError(f"Value {type_cols} is not valid for parameter `type_cols`!") values = {} @@ -83,13 +91,7 @@ def mean_squared_error(df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFra ------- pd.Series """ - cols_numerical = _get_numerical_features(df1) - return columnwise_metric( - df1[cols_numerical], - df2[cols_numerical], - df_mask[cols_numerical], - skm.mean_squared_error, - ) + return columnwise_metric(df1, df2, df_mask, skm.mean_squared_error, type_cols="numerical") def root_mean_squared_error( @@ -110,13 +112,8 @@ def root_mean_squared_error( ------- pd.Series """ - cols_numerical = _get_numerical_features(df1) return columnwise_metric( - df1[cols_numerical], - df2[cols_numerical], - df_mask[cols_numerical], - skm.mean_squared_error, - squared=False, + df1, df2, df_mask, skm.mean_squared_error, type_cols="numerical", squared=False ) @@ -136,13 +133,7 @@ def mean_absolute_error(df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFr ------- pd.Series """ - cols_numerical = _get_numerical_features(df1) - return columnwise_metric( - df1[cols_numerical], - df2[cols_numerical], - df_mask[cols_numerical], - skm.mean_absolute_error, - ) + return columnwise_metric(df1, df2, df_mask, skm.mean_absolute_error, type_cols="numerical") 
def mean_absolute_percentage_error( @@ -163,12 +154,8 @@ def mean_absolute_percentage_error( ------- pd.Series """ - cols_numerical = _get_numerical_features(df1) return columnwise_metric( - df1[cols_numerical], - df2[cols_numerical], - df_mask[cols_numerical], - skm.mean_absolute_percentage_error, + df1, df2, df_mask, skm.mean_absolute_percentage_error, type_cols="numerical" ) @@ -209,13 +196,45 @@ def weighted_mean_absolute_percentage_error( ------- pd.Series """ - return columnwise_metric(df1, df2, df_mask, _weighted_mean_absolute_percentage_error_1D) + return columnwise_metric( + df1, + df2, + df_mask, + _weighted_mean_absolute_percentage_error_1D, + type_cols="numerical", + ) -def accuracy(values1: pd.Series, values2: pd.Series) -> float: +def accuracy(df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame) -> pd.Series: """ Matching ratio beetween the two datasets. + Parameters + ---------- + df1 : pd.DataFrame + True dataframe + df2 : pd.DataFrame + Predicted dataframe + df_mask : pd.DataFrame + Elements of the dataframes to compute on + + Returns + ------- + pd.Series + """ + return columnwise_metric( + df1, + df2, + df_mask, + accuracy_1D, + type_cols="all", + ) + + +def accuracy_1D(values1: pd.Series, values2: pd.Series) -> float: + """ + Matching ratio beetween the set of values. 
+ Parameters ---------- values1 : pd.Series @@ -352,13 +371,7 @@ def kolmogorov_smirnov_test( pd.Series KS test statistic """ - cols_numerical = _get_numerical_features(df1) - return columnwise_metric( - df1[cols_numerical], - df2[cols_numerical], - df_mask[cols_numerical], - kolmogorov_smirnov_test_1D, - ) + return columnwise_metric(df1, df2, df_mask, kolmogorov_smirnov_test_1D, type_cols="numerical") def _total_variance_distance_1D(df1: pd.Series, df2: pd.Series) -> float: @@ -439,9 +452,7 @@ def _get_correlation_pearson_matrix(df: pd.DataFrame, use_p_value: bool = True) matrix = np.zeros((len(df.columns), len(df.columns))) for idx_1, col_1 in enumerate(cols): for idx_2, col_2 in enumerate(cols): - res = scipy.stats.mstats.pearsonr( - df[col_1].array.reshape(-1, 1), df[col_2].array.reshape(-1, 1) - ) + res = scipy.stats.mstats.pearsonr(df[[col_1]].values, df[[col_2]].values) if use_p_value: matrix[idx_1, idx_2] = res[1] else: @@ -755,7 +766,6 @@ def sum_pairwise_distances( def frechet_distance( df1: pd.DataFrame, df2: pd.DataFrame, - df_mask: pd.DataFrame, ) -> float: """Compute the Fréchet distance between two dataframes df1 and df2 Frechet_distance = || mu_1 - mu_2 ||_2^2 + Tr(Sigma_1 + Sigma_2 - 2(Sigma_1 . 
Sigma_2)^(1/2)) @@ -770,8 +780,6 @@ def frechet_distance( true dataframe df2 : pd.DataFrame predicted dataframe - df_mask : pd.DataFrame - Mask indicating on which values the distance has to computed on Returns ------- @@ -782,35 +790,22 @@ def frechet_distance( if df1.shape != df2.shape: raise Exception("inputs have to be of same dimensions.") - df_true = df1[df_mask.any(axis=1)] - df_pred = df2[df_mask.any(axis=1)] - - std = (np.std(df_true) + np.std(df_pred) + EPS) / 2 - mu = (np.nanmean(df_true, axis=0) + np.nanmean(df_pred, axis=0)) / 2 - df_true = (df_true - mu) / std - df_pred = (df_pred - mu) / std + std = (np.std(df1) + np.std(df2) + EPS) / 2 + mu = (np.nanmean(df1, axis=0) + np.nanmean(df2, axis=0)) / 2 + df1 = (df1 - mu) / std + df2 = (df2 - mu) / std - mu_true = np.nanmean(df_true, axis=0) - sigma_true = np.ma.cov(np.ma.masked_invalid(df_true), rowvar=False).data - mu_pred = np.nanmean(df_pred, axis=0) - sigma_pred = np.ma.cov(np.ma.masked_invalid(df_pred), rowvar=False).data + means1, cov1 = utils.nan_mean_cov(df1.values) + means2, cov2 = utils.nan_mean_cov(df2.values) - ssdiff = np.sum((mu_true - mu_pred) ** 2.0) - product = np.array(sigma_true @ sigma_pred) - if product.ndim < 2: - product = product.reshape(-1, 1) - covmean = scipy.linalg.sqrtm(product) - if np.iscomplexobj(covmean): - covmean = covmean.real - frechet_dist = ssdiff + np.trace(sigma_true + sigma_pred - 2.0 * covmean) - - return frechet_dist / df_true.shape[0] + return algebra.frechet_distance_exact(means1, cov1, means2, cov2) def frechet_distance_pattern( df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame, + min_n_rows: int = 10, ) -> pd.Series: """Frechet distance computed using a pattern decomposition @@ -822,15 +817,23 @@ def frechet_distance_pattern( Second empirical ditribution df_mask : pd.DataFrame Mask indicating on which values the distance has to computed on + min_n_rows: int + Minimum number of rows for a KL estimation Returns ------- pd.Series Series of computed 
metrics """ - cols_numerical = _get_numerical_features(df1) - distance = frechet_distance(df1[cols_numerical], df2[cols_numerical], df_mask[cols_numerical]) - return pd.Series(distance, index=["All"]) + + return pattern_based_weighted_mean_metric( + df1, + df2, + df_mask, + frechet_distance, + min_n_rows=min_n_rows, + type_cols="numerical", + ) def kl_divergence_1D(df1: pd.Series, df2: pd.Series) -> float: @@ -858,39 +861,7 @@ def kl_divergence_1D(df1: pd.Series, df2: pd.Series) -> float: return scipy.stats.entropy(p + EPS, q + EPS) -def kl_divergence_gaussian_exact( - mean1: pd.Series, cov1: pd.DataFrame, mean2: pd.Series, cov2: pd.DataFrame -) -> float: - """Exact Kullback-Leibler divergence computed between two multivariate normal distributions - - Parameters - ---------- - mean1: pd.Series - Mean of the first distribution - cov1: pd.DataFrame - Covariance matrx of the first distribution - mean2: pd.Series - Mean of the second distribution - cov2: pd.DataFrame - Covariance matrx of the second distribution - Returns - ------- - float - Kulback-Leibler divergence - """ - n_variables = len(mean1) - L1, lower1 = scipy.linalg.cho_factor(cov1) - L2, lower2 = scipy.linalg.cho_factor(cov2) - M = scipy.linalg.solve(L2, L1) - y = scipy.linalg.solve(L2, mean2 - mean1) - norm_M = (M**2).sum().sum() - norm_y = (y**2).sum() - term_diag_L = 2 * np.sum(np.log(np.diagonal(L2) / np.diagonal(L1))) - div_kl = 0.5 * (norm_M - n_variables + norm_y + term_diag_L) - return div_kl - - -def kl_divergence_gaussian(df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.Series) -> float: +def kl_divergence_gaussian(df1: pd.DataFrame, df2: pd.DataFrame) -> float: """Kullback-Leibler divergence estimation based on a Gaussian approximation of both empirical distributions @@ -900,29 +871,29 @@ def kl_divergence_gaussian(df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.Ser First empirical distribution df2 : pd.DataFrame Second empirical distribution - df_mask: pd.DataFrame - Mask indicating on what 
values the divergence should be computed Returns ------- pd.Series Series of estimated metrics """ - df1 = df1[df_mask.any(axis=1)] - df2 = df2[df_mask.any(axis=1)] - cov1 = df1.cov() - cov2 = df2.cov() - mean1 = df1.mean() - mean2 = df2.mean() - - div_kl = kl_divergence_gaussian_exact(mean1, cov1, mean2, cov2) + cov1 = df1.cov().values + cov2 = df2.cov().values + means1 = np.array(df1.mean()) + means2 = np.array(df2.mean()) + try: + div_kl = algebra.kl_divergence_gaussian_exact(means1, cov1, means2, cov2) + except LinAlgError: + raise ValueError( + "Provided datasets have degenerate colinearities, KL-divergence cannot be computed!" + ) return div_kl -def kl_divergence( +def kl_divergence_pattern( df1: pd.DataFrame, df2: pd.DataFrame, - df_mask: pd.Series, + df_mask: pd.DataFrame, method: str = "columnwise", min_n_rows: int = 10, ) -> pd.Series: @@ -958,21 +929,15 @@ def kl_divergence( Consider using a larger dataset of lowering the parameter `min_n_rows`. """ if method == "columnwise": - cols_numerical = _get_numerical_features(df1) - return columnwise_metric( - df1[cols_numerical], - df2[cols_numerical], - df_mask[cols_numerical], - kl_divergence_1D, - ) + return columnwise_metric(df1, df2, df_mask, kl_divergence_1D, type_cols="numerical") elif method == "gaussian": - cols_numerical = _get_numerical_features(df1) return pattern_based_weighted_mean_metric( - df1[cols_numerical], - df2[cols_numerical], - df_mask[cols_numerical], + df1, + df2, + df_mask, kl_divergence_gaussian, min_n_rows=min_n_rows, + type_cols="numerical", ) else: raise AssertionError( @@ -981,7 +946,7 @@ def kl_divergence( ) -def distance_anticorr(df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame) -> float: +def distance_anticorr(df1: pd.DataFrame, df2: pd.DataFrame) -> float: """Score based on the distance anticorrelation between two empirical distributions. 
The theoretical basis can be found on dcor documentation: https://dcor.readthedocs.io/en/latest/theory.html @@ -992,25 +957,57 @@ def distance_anticorr(df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFram Dataframe representing the first empirical distribution df2 : pd.DataFrame Dataframe representing the second empirical distribution - df_mask: pd.DataFrame - Mask indicating on what values the divergence should be computed Returns ------- float Distance correlation score """ - df1 = df1.loc[df_mask.any(axis=1)] - df2 = df2.loc[df_mask.any(axis=1)] return (1 - dcor.distance_correlation(df1.values, df2.values)) / 2 +def distance_anticorr_pattern( + df1: pd.DataFrame, + df2: pd.DataFrame, + df_mask: pd.DataFrame, + min_n_rows: int = 10, +) -> pd.Series: + """Correlation distance computed using a pattern decomposition + + Parameters + ---------- + df1 : pd.DataFrame + First empirical ditribution + df2 : pd.DataFrame + Second empirical ditribution + df_mask : pd.DataFrame + Mask indicating on which values the distance has to computed on + min_n_rows: int + Minimum number of rows for a KL estimation + + Returns + ------- + pd.Series + Series of computed metrics + """ + + return pattern_based_weighted_mean_metric( + df1, + df2, + df_mask, + distance_anticorr, + min_n_rows=min_n_rows, + type_cols="numerical", + ) + + def pattern_based_weighted_mean_metric( df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame, metric: Callable, min_n_rows: int = 10, + type_cols: str = "all", **kwargs, ) -> pd.Series: """Compute a mean score based on missing patterns. 
@@ -1035,22 +1032,34 @@ def pattern_based_weighted_mean_metric( pd.Series _description_ """ + if type_cols == "all": + cols = df1.columns + elif type_cols == "numerical": + cols = df1.select_dtypes(include=["number"]).columns + elif type_cols == "categorical": + cols = df1.select_dtypes(exclude=["number"]).columns + else: + raise ValueError(f"Value {type_cols} is not valid for parameter `type_cols`!") + if np.any(df_mask & df1.isna()): + raise ValueError("The argument df1 has missing values on the mask!") + if np.any(df_mask & df2.isna()): + raise ValueError("The argument df2 has missing values on the mask!") + rows_mask = df_mask.any(axis=1) scores = [] weights = [] - df1 = df1.loc[df_mask.any(axis=1)] - df2 = df2.loc[df_mask.any(axis=1)] - df_nan = df1.notna() + df1 = df1[cols].loc[rows_mask] + df2 = df2[cols].loc[rows_mask] + df_mask = df_mask[cols].loc[rows_mask] max_num_row = 0 - for tup_pattern, df_nan_pattern in df_nan.groupby(df_nan.columns.tolist()): - ind_pattern = df_nan_pattern.index + for tup_pattern, df_mask_pattern in df_mask.groupby(df_mask.columns.tolist()): + ind_pattern = df_mask_pattern.index df1_pattern = df1.loc[ind_pattern, list(tup_pattern)] max_num_row = max(max_num_row, len(df1_pattern)) if not any(tup_pattern) or len(df1_pattern) < min_n_rows: continue df2_pattern = df2.loc[ind_pattern, list(tup_pattern)] - df_mask_pattern = df_mask.loc[ind_pattern, list(tup_pattern)] weights.append(len(df1_pattern) / len(df1)) - scores.append(metric(df1_pattern, df2_pattern, df_mask_pattern, **kwargs)) + scores.append(metric(df1_pattern, df2_pattern, **kwargs)) if len(scores) == 0: raise NotEnoughSamples(max_num_row, min_n_rows) return pd.Series(sum([s * w for s, w in zip(scores, weights)]), index=["All"]) @@ -1062,20 +1071,14 @@ def get_metric(name: str) -> Callable: "rmse": root_mean_squared_error, "mae": mean_absolute_error, "wmape": weighted_mean_absolute_percentage_error, - "accuracy": partial( - columnwise_metric, - metric=accuracy, - ), + 
"accuracy": accuracy, "wasserstein_columnwise": dist_wasserstein, - "KL_columnwise": partial(kl_divergence, method="columnwise"), - "KL_gaussian": partial(kl_divergence, method="gaussian"), + "KL_columnwise": partial(kl_divergence_pattern, method="columnwise"), + "KL_gaussian": partial(kl_divergence_pattern, method="gaussian"), "ks_test": kolmogorov_smirnov_test, "correlation_diff": mean_difference_correlation_matrix_numerical_features, "energy": sum_energy_distances, "frechet": frechet_distance_pattern, - "dist_corr_pattern": partial( - pattern_based_weighted_mean_metric, - metric=distance_anticorr, - ), + "dist_corr_pattern": distance_anticorr_pattern, } return dict_metrics[name] diff --git a/qolmat/imputations/em_sampler.py b/qolmat/imputations/em_sampler.py index cd41c86d..785b206a 100644 --- a/qolmat/imputations/em_sampler.py +++ b/qolmat/imputations/em_sampler.py @@ -466,14 +466,14 @@ def _check_conditionning(self, X: NDArray): IllConditioned Data matrix is ill-conditioned due to colinear columns. """ - n_rows, n_cols = X.shape + n_samples, n_cols = X.shape # if n_rows == 1 the function np.cov returns a float - if n_rows == 1: - min_sv = 0 - else: - cov = np.cov(X, bias=True, rowvar=False).reshape(n_cols, -1) - _, sv, _ = spl.svd(cov) - min_sv = min(np.sqrt(sv)) + if n_samples == 1: + raise ValueError("EM cannot be fitted when n_samples = 1!") + + cov = np.cov(X, bias=True, rowvar=False).reshape(n_cols, -1) + _, sv, _ = spl.svd(cov) + min_sv = min(np.sqrt(sv)) if min_sv < self.min_std: warnings.warn( f"The covariance matrix is ill-conditioned, indicating high-colinearity: the " @@ -481,7 +481,6 @@ def _check_conditionning(self, X: NDArray): f"min_std ({min_sv} < {self.min_std}). Consider removing columns of decreasing " f"the threshold." 
) - # raise IllConditioned(min_sv, self.min_std) class MultiNormalEM(EM): @@ -683,8 +682,7 @@ def fit_parameters_with_missingness(self, X: NDArray): X : NDArray Data matrix with missingness """ - self.means = np.nanmean(X, axis=0) - self.cov = utils.nancov(X) + self.means, self.cov = utils.nan_mean_cov(X) self.cov_inv = np.linalg.pinv(self.cov) def set_parameters(self, means: NDArray, cov: NDArray): diff --git a/qolmat/imputations/preprocessing.py b/qolmat/imputations/preprocessing.py index 77d500a0..29d48e58 100644 --- a/qolmat/imputations/preprocessing.py +++ b/qolmat/imputations/preprocessing.py @@ -35,17 +35,10 @@ class MixteHGBM(RegressorMixin, BaseEstimator): A custom scikit-learn estimator implementing a mixed model using HistGradientBoostingClassifier for string target data and HistGradientBoostingRegressor for numeric target data. - - Parameters - ---------- - allow_new : bool, default=True - Whether to allow new categories in numerical target data. If false the predictions are - mapped to the closest existing value. 
""" - def __init__(self, allow_new=True): + def __init__(self): super().__init__() - self.allow_new = allow_new def set_model_parameters(self, **args_model): """ @@ -150,11 +143,12 @@ def fit(self, X: NDArray, y: Optional[NDArray] = None) -> Self: self.feature_names_in_ = df.columns self.n_features_in_ = len(df.columns) self.dict_df_bins_: Dict[Hashable, pd.DataFrame] = dict() - cols = df.columns if self.cols is None else self.cols + if self.cols is None: + cols = df.select_dtypes(include="number").columns + else: + cols = self.cols for col in cols: values = df[col] - if not pd.api.types.is_numeric_dtype(values): - raise TypeError values = values.dropna() df_bins = pd.DataFrame({"value": np.sort(values.unique())}) df_bins["min"] = (df_bins["value"] + df_bins["value"].shift()) / 2 @@ -297,15 +291,17 @@ def transform(self, X: NDArray) -> NDArray: def make_pipeline_mixte_preprocessing( - scale_numerical: bool = True, -) -> BaseEstimator: + scale_numerical: bool = False, avoid_new: bool = False +) -> Pipeline: """ Create a preprocessing pipeline managing mixed type data by one hot encoding categorical data. Parameters ---------- - scale_numerical : bool, default=True + scale_numerical : bool, default=False Whether to scale numerical features. + avoid_new : bool, default=False + Whether to forbid new numerical values. 
Returns ------- @@ -315,13 +311,17 @@ def make_pipeline_mixte_preprocessing( transformers: List[Tuple] = [] if scale_numerical: transformers += [("num", StandardScaler(), selector(dtype_include=np.number))] + ohe = OneHotEncoder(handle_unknown="ignore", use_cat_names=True) transformers += [("cat", ohe, selector(dtype_exclude=np.number))] - preprocessor = ColumnTransformer(transformers=transformers).set_output(transform="pandas") + col_transformer = ColumnTransformer(transformers=transformers).set_output(transform="pandas") + preprocessor = Pipeline(steps=[("col_transformer", col_transformer)]) + if avoid_new: + preprocessor.steps.append(("bins", BinTransformer())) return preprocessor -def make_robust_MixteHGB(scale_numerical: bool = True, allow_new: bool = True) -> Pipeline: +def make_robust_MixteHGB(scale_numerical: bool = False, avoid_new: bool = False) -> Pipeline: """ Create a robust pipeline for MixteHGBM by one hot encoding categorical features. This estimator is intended for use in ImputerRegressor to deal with mixed type data. @@ -332,10 +332,10 @@ def make_robust_MixteHGB(scale_numerical: bool = True, allow_new: bool = True) - Parameters ---------- - scale_numerical : bool, default=True + scale_numerical : bool, default=False Whether to scale numerical features. - allow_new : bool, default=True - Whether to allow new categories. + avoid_new : bool, default=False + Whether to forbid new numerical values. Returns ------- @@ -343,12 +343,12 @@ def make_robust_MixteHGB(scale_numerical: bool = True, allow_new: bool = True) - A robust pipeline for MixteHGBM. 
""" preprocessor = make_pipeline_mixte_preprocessing( - scale_numerical=scale_numerical, + scale_numerical=scale_numerical, avoid_new=avoid_new ) robust_MixteHGB = Pipeline( steps=[ ("preprocessor", preprocessor), - ("estimator", MixteHGBM(allow_new=allow_new)), + ("estimator", MixteHGBM()), ] ) diff --git a/qolmat/utils/algebra.py b/qolmat/utils/algebra.py new file mode 100644 index 00000000..9e2af1a6 --- /dev/null +++ b/qolmat/utils/algebra.py @@ -0,0 +1,83 @@ +import numpy as np +import scipy +from numpy.typing import NDArray, ArrayLike + + +def frechet_distance_exact( + means1: NDArray, + cov1: NDArray, + means2: NDArray, + cov2: NDArray, +) -> float: + """Compute the Fréchet distance between two dataframes df1 and df2 + Frechet_distance = || mu_1 - mu_2 ||_2^2 + Tr(Sigma_1 + Sigma_2 - 2(Sigma_1 . Sigma_2)^(1/2)) + It is normalized, df1 and df2 are first scaled by a factor (std(df1) + std(df2)) / 2 + and then centered around (mean(df1) + mean(df2)) / 2 + The result is divided by the number of samples to get an homogeneous result. + Based on: Dowson, D. C., and BV666017 Landau. "The Fréchet distance between multivariate normal + distributions." Journal of multivariate analysis 12.3 (1982): 450-455. 
+ + Parameters + ---------- + means1 : NDArray + Means of the first distribution + cov1 : NDArray + Covariance matrix of the first distribution + means2 : NDArray + Means of the second distribution + cov2 : NDArray + Covariance matrix of the second distribution + + Returns + ------- + float + Frechet distance + """ + n = len(means1) + if (means2.shape != (n,)) or (cov1.shape != (n, n)) or (cov2.shape != (n, n)): + raise ValueError("Inputs have to be of same dimensions.") + + ssdiff = np.sum((means1 - means2) ** 2.0) + product = np.array(cov1 @ cov2) + if product.ndim < 2: + product = product.reshape(-1, 1) + covmean = scipy.linalg.sqrtm(product) + if np.iscomplexobj(covmean): + covmean = covmean.real + frechet_dist = ssdiff + np.trace(cov1 + cov2 - 2.0 * covmean) + + return frechet_dist / n + + +def kl_divergence_gaussian_exact( + means1: NDArray, cov1: NDArray, means2: NDArray, cov2: NDArray +) -> float: + """ + Exact Kullback-Leibler divergence computed between two multivariate normal distributions + Based on https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence + + Parameters + ---------- + means1: NDArray + Mean of the first distribution + cov1: NDArray + Covariance matrx of the first distribution + means2: NDArray + Mean of the second distribution + cov2: NDArray + Covariance matrx of the second distribution + Returns + ------- + float + Kulback-Leibler divergence + """ + n_variables = len(means1) + L1, _ = scipy.linalg.cho_factor(cov1) + L2, _ = scipy.linalg.cho_factor(cov2) + M = scipy.linalg.solve(L2, L1) + y = scipy.linalg.solve(L2, means2 - means1) + norm_M = (M**2).sum().sum() + norm_y = (y**2).sum() + term_diag_L = 2 * np.sum(np.log(np.diagonal(L2) / np.diagonal(L1))) + div_kl = 0.5 * (norm_M - n_variables + norm_y + term_diag_L) + return div_kl diff --git a/qolmat/utils/utils.py b/qolmat/utils/utils.py index d036d9d5..36771ecc 100644 --- a/qolmat/utils/utils.py +++ b/qolmat/utils/utils.py @@ -250,16 +250,11 @@ def create_lag_matrices(X: 
NDArray, p: int) -> Tuple[NDArray, NDArray]: return Z, Y -def nancov(X: NDArray) -> NDArray: - _, n_cols = X.shape - cov = np.nan * np.zeros((n_cols, n_cols)) - mask = np.isnan(X) - for i in range(n_cols): - Di = X[:, i] - np.nanmean(X[:, i]) - for j in range(n_cols): - select = (~mask[:, i]) & (~mask[:, j]) - Di = X[select, i] - np.mean(X[select, i]) - Dj = X[select, j] - np.mean(X[select, j]) - cov[i, j] = np.nanmean(Di * Dj) - cov = impute_nans(cov, method="zeros") - return cov +def nan_mean_cov(X: NDArray) -> Tuple[NDArray, NDArray]: + _, n_variables = X.shape + means = np.nanmean(X, axis=0) + cov = np.ma.cov(np.ma.masked_invalid(X), rowvar=False).data + print(cov.shape) + print(X.shape) + cov = cov.reshape(n_variables, n_variables) + return means, cov diff --git a/tests/benchmark/test_metrics.py b/tests/benchmark/test_metrics.py index df08fe8e..b2b1f4b7 100644 --- a/tests/benchmark/test_metrics.py +++ b/tests/benchmark/test_metrics.py @@ -2,6 +2,7 @@ # # Evaluation metrics # # ###################### +from math import exp import numpy as np from numpy import random as npr import pandas as pd @@ -97,6 +98,18 @@ def test_weighted_mean_absolute_percentage_error( np.testing.assert_allclose(result, expected, atol=1e-3) +@pytest.mark.parametrize("df1", [df_incomplete]) +@pytest.mark.parametrize("df2", [df_imputed]) +@pytest.mark.parametrize("df_mask", [df_mask]) +def test_accuracy(df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame) -> None: + result = metrics.accuracy(df1, df1, df_mask) + expected = pd.Series([1.0, 1.0], index=["col1", "col2"]) + pd.testing.assert_series_equal(result, expected) + result = metrics.accuracy(df1, df2, df_mask) + expected = pd.Series([0.5, 0.0], index=["col1", "col2"]) + pd.testing.assert_series_equal(result, expected, atol=1e-3) + + @pytest.mark.parametrize("df1", [df_incomplete]) @pytest.mark.parametrize("df2", [df_imputed]) @pytest.mark.parametrize("df_mask", [df_mask]) @@ -110,15 +123,19 @@ def test_wasserstein_distance(df1: 
pd.DataFrame, df2: pd.DataFrame, df_mask: pd. @pytest.mark.parametrize("df1", [df_incomplete]) @pytest.mark.parametrize("df2", [df_imputed]) @pytest.mark.parametrize("df_mask", [df_mask]) -def test_kl_divergence_columnwise( - df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame -) -> None: - result = metrics.kl_divergence(df1, df1, df_mask, method="columnwise") +def test_kl_divergence(df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame) -> None: + result = metrics.kl_divergence_pattern(df1, df1, df_mask, method="columnwise") expected = pd.Series([0.0, 0.0], index=["col1", "col2"]) - np.testing.assert_allclose(result, expected, atol=1e-3) - result = metrics.kl_divergence(df1, df2, df_mask, method="columnwise") + pd.testing.assert_series_equal(result, expected, atol=1e-3) + + result = metrics.kl_divergence_pattern(df1, df2, df_mask, method="columnwise") expected = pd.Series([18.945, 36.637], index=["col1", "col2"]) - np.testing.assert_allclose(result, expected, atol=1e-3) + pd.testing.assert_series_equal(result, expected, atol=1e-3) + + df_nonan = df1.notna() + result = metrics.kl_divergence_pattern(df1, df2, df_nonan, method="gaussian", min_n_rows=2) + expected = pd.Series([1.029], index=["All"]) + pd.testing.assert_series_equal(result, expected, atol=1e-3) @pytest.mark.parametrize("df1", [df_incomplete]) @@ -127,22 +144,22 @@ def test_kl_divergence_columnwise( def test_kl_divergence_gaussian( df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame ) -> None: - result = metrics.kl_divergence_gaussian(df1, df1, df_mask) - np.testing.assert_allclose(result, 0, atol=1e-3) + result = metrics.kl_divergence_gaussian(df1, df1) + np.testing.assert_almost_equal(result, 0, decimal=3) - result = metrics.kl_divergence_gaussian(df1, df2, df_mask) - np.testing.assert_allclose(result, 1.371, atol=1e-3) + result = metrics.kl_divergence_gaussian(df1, df2) + expected = 0.669308 + np.testing.assert_almost_equal(result, expected, decimal=3) 
@pytest.mark.parametrize("df1", [df_incomplete]) @pytest.mark.parametrize("df2", [df_imputed]) -@pytest.mark.parametrize("df_mask", [df_mask]) -def test_frechet_distance(df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame) -> None: - result = metrics.frechet_distance(df1, df1, df_mask) +def test_frechet_distance(df1: pd.DataFrame, df2: pd.DataFrame) -> None: + result = metrics.frechet_distance(df1, df1) np.testing.assert_allclose(result, 0, atol=1e-3) - result = metrics.frechet_distance(df1, df2, df_mask) - np.testing.assert_allclose(result, 0.253, atol=1e-3) + result = metrics.frechet_distance(df1, df2) + np.testing.assert_allclose(result, 0.134, atol=1e-3) @pytest.mark.parametrize("df1", [df_incomplete]) @@ -303,7 +320,7 @@ def test_exception_raise_different_shapes( with pytest.raises(Exception): metrics.mean_difference_correlation_matrix_numerical_features(df1, df2, df_mask) with pytest.raises(Exception): - metrics.frechet_distance(df1, df2, df_mask) + metrics.frechet_distance(df1, df2) @pytest.mark.parametrize("df1", [df_incomplete_cat]) @@ -344,14 +361,6 @@ def test_value_error_get_correlation_f_oneway_matrix( ).equals(pd.Series([np.nan], index=["col1"])) -@pytest.mark.parametrize("df1", [df_incomplete]) -@pytest.mark.parametrize("df2", [df_imputed]) -@pytest.mark.parametrize("df_mask", [df_mask]) -def test_distance_anticorr(df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame) -> None: - result = metrics.distance_anticorr(df1, df2, df_mask) - np.testing.assert_allclose(result, 1.1e-4, rtol=1e-2) - - @pytest.mark.parametrize("df1", [df_incomplete]) @pytest.mark.parametrize("df2", [df_imputed]) @pytest.mark.parametrize("df_mask", [df_mask]) @@ -359,29 +368,39 @@ def test_pattern_based_weighted_mean_metric( df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame ) -> None: with pytest.raises(NotEnoughSamples): - metrics.pattern_based_weighted_mean_metric( - df1, df2, df_mask, metric=metrics.distance_anticorr, min_n_rows=5 - ) + 
metrics.distance_anticorr_pattern(df1, df2, df_mask, min_n_rows=5) - expected = pd.Series([1.1e-4], index=["All"]) - result = metrics.pattern_based_weighted_mean_metric( - df1, df2, df_mask, metric=metrics.distance_anticorr, min_n_rows=1 - ) + expected = pd.Series([1 / 6], index=["All"]) + result = metrics.distance_anticorr_pattern(df1, df2, df_mask, min_n_rows=1) np.testing.assert_allclose(result, expected, rtol=1e-2) rng = npr.default_rng(123) df_gauss1 = pd.DataFrame(rng.multivariate_normal([0, 0], [[1, 0.2], [0.2, 2]], size=100)) df_gauss2 = pd.DataFrame(rng.multivariate_normal([0, 1], [[1, 0.2], [0.2, 2]], size=100)) -df_mask = pd.DataFrame(np.full_like(df_gauss1, True)) +df_mask_gauss = pd.DataFrame(np.full_like(df_gauss1, True)) + + +def test_pattern_mae_comparison(mocker) -> None: + # def mock_metric(values1: pd.Series, values2: pd.Series) -> float: + # call_count += 1 + # return 0 -def test_pattern_mae_comparison() -> None: - def fun_mean_mae(df_gauss1, df_gauss2, df_mask) -> float: - return metrics.mean_squared_error(df_gauss1, df_gauss2, df_mask).mean() + mock_metric = mocker.patch("qolmat.benchmark.metrics.accuracy_1D", return_value=0) + # def fun_mean_mae(df_gauss1, df_gauss2, df_mask_gauss) -> float: + # return metrics.mean_squared_error(df_gauss1, df_gauss2, df_mask_gauss).mean() - result1 = fun_mean_mae(df_gauss1, df_gauss2, df_mask) - result2 = metrics.pattern_based_weighted_mean_metric( - df_gauss1, df_gauss2, df_mask, metric=fun_mean_mae, min_n_rows=1 + print(df_mask) + df_nonan = df_incomplete.notna() + result = metrics.pattern_based_weighted_mean_metric( + df_incomplete, df_imputed, df_nonan, metric=mock_metric, min_n_rows=1 ) - np.testing.assert_allclose(result1, result2, rtol=1e-2) + print(result) + assert mock_metric.call_count == 2 + + +def test_get_metric(): + expected = metrics.accuracy(df_incomplete, df_imputed, df_mask) + result = metrics.get_metric("accuracy")(df_incomplete, df_imputed, df_mask) + 
pd.testing.assert_series_equal(expected, result) diff --git a/tests/imputations/test_preprocessing.py b/tests/imputations/test_preprocessing.py index cfc25494..30b55bd3 100644 --- a/tests/imputations/test_preprocessing.py +++ b/tests/imputations/test_preprocessing.py @@ -221,8 +221,12 @@ def test_make_robust_MixteHGB(robust_mixte_hgb_model): # Ensure the pipeline is constructed correctly assert isinstance(robust_mixte_hgb_model, Pipeline) - # Ensure the preprocessor in the pipeline is of type ColumnTransformer - assert isinstance(robust_mixte_hgb_model.named_steps["preprocessor"], ColumnTransformer) + dict_steps = robust_mixte_hgb_model.named_steps + assert len(dict_steps) == 2 + # Ensure the preprocessor in the pipeline is of type Pipeline + assert isinstance(dict_steps["preprocessor"], Pipeline) + # Ensure the estimator in the pipeline is of type MixteHGBM + assert isinstance(dict_steps["estimator"], MixteHGBM) # Test fitting and predicting with numeric target X_train, X_test, y_train, y_test = train_test_split( diff --git a/tests/utils/test_algebra.py b/tests/utils/test_algebra.py new file mode 100644 index 00000000..45a508c8 --- /dev/null +++ b/tests/utils/test_algebra.py @@ -0,0 +1,31 @@ +import numpy as np +from sympy import diag + +from qolmat.utils import algebra + + +def test_frechet_distance_exact(): + means1 = np.array([0, 1, 3]) + stds = np.array([1, 1, 1]) + cov1 = np.diag(stds**2) + + means2 = np.array([0, -1, 1]) + cov2 = np.eye(3, 3) + + expected = np.sum((means2 - means1) ** 2) + np.sum((np.sqrt(stds) - 1) ** 2) + expected /= 3 + result = algebra.frechet_distance_exact(means1, cov1, means2, cov2) + np.testing.assert_almost_equal(result, expected, decimal=3) + + +def test_kl_divergence_gaussian_exact(): + means1 = np.array([0, 1, 3]) + stds = np.array([1, 2, 3]) + cov1 = np.diag(stds**2) + + means2 = np.array([0, -1, 1]) + cov2 = np.eye(3, 3) + + expected = (np.sum(stds**2 - np.log(stds**2) - 1 + (means2 - means1) ** 2)) / 2 + result = 
algebra.kl_divergence_gaussian_exact(means1, cov1, means2, cov2) + np.testing.assert_almost_equal(result, expected, decimal=3) From a2edc535168e8817bcf5f1d1da770f9e2d8b61a5 Mon Sep 17 00:00:00 2001 From: Julien Roussel <3178729-JulienRoussel77@users.noreply.gitlab.com> Date: Mon, 15 Apr 2024 10:24:29 +0200 Subject: [PATCH 2/3] frechet distance refacto --- qolmat/benchmark/metrics.py | 104 +++++++++++--------------------- qolmat/utils/utils.py | 52 +++++++++++++++- tests/benchmark/test_metrics.py | 14 ++--- 3 files changed, 92 insertions(+), 78 deletions(-) diff --git a/qolmat/benchmark/metrics.py b/qolmat/benchmark/metrics.py index 3a2699b7..b68d3e6b 100644 --- a/qolmat/benchmark/metrics.py +++ b/qolmat/benchmark/metrics.py @@ -59,9 +59,9 @@ def columnwise_metric( if type_cols == "all": cols = df1.columns elif type_cols == "numerical": - cols = _get_numerical_features(df1) + cols = utils._get_numerical_features(df1) elif type_cols == "categorical": - cols = _get_categorical_features(df1) + cols = utils._get_categorical_features(df1) else: raise ValueError(f"Value {type_cols} is not valid for parameter `type_cols`!") values = {} @@ -282,56 +282,6 @@ def dist_wasserstein( ) -def _get_numerical_features(df1: pd.DataFrame) -> List[str]: - """Get numerical features from dataframe - - Parameters - ---------- - df1 : pd.DataFrame - - Returns - ------- - List[str] - List of numerical features - - Raises - ------ - Exception - No numerical feature is found - """ - cols_numerical = df1.select_dtypes(include=np.number).columns.tolist() - if len(cols_numerical) == 0: - raise Exception("No numerical feature is found.") - else: - return cols_numerical - - -def _get_categorical_features(df1: pd.DataFrame) -> List[str]: - """Get categorical features from dataframe - - Parameters - ---------- - df1 : pd.DataFrame - - Returns - ------- - List[str] - List of categorical features - - Raises - ------ - Exception - No categorical feature is found - """ - - cols_numerical = 
df1.select_dtypes(include=np.number).columns.tolist() - cols_categorical = [col for col in df1.columns.to_list() if col not in cols_numerical] - if len(cols_categorical) == 0: - raise Exception("No categorical feature is found.") - else: - return cols_categorical - - def kolmogorov_smirnov_test_1D(df1: pd.Series, df2: pd.Series) -> float: """Compute KS test statistic of the two-sample Kolmogorov-Smirnov test for goodness of fit. See more in https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ks_2samp.html. @@ -418,7 +368,7 @@ def total_variance_distance( pd.Series Total variance distance """ - cols_categorical = _get_categorical_features(df1) + cols_categorical = utils._get_categorical_features(df1) return columnwise_metric( df1[cols_categorical], df2[cols_categorical], @@ -491,7 +441,7 @@ def mean_difference_correlation_matrix_numerical_features( _check_same_number_columns(df1, df2) - cols_numerical = _get_numerical_features(df1) + cols_numerical = utils._get_numerical_features(df1) df_corr1 = _get_correlation_pearson_matrix(df1[cols_numerical], use_p_value=use_p_value) df_corr2 = _get_correlation_pearson_matrix(df2[cols_numerical], use_p_value=use_p_value) @@ -560,7 +510,7 @@ def mean_difference_correlation_matrix_categorical_features( _check_same_number_columns(df1, df2) - cols_categorical = _get_categorical_features(df1) + cols_categorical = utils._get_categorical_features(df1) df_corr1 = _get_correlation_chi2_matrix(df1[cols_categorical], use_p_value=use_p_value) df_corr2 = _get_correlation_chi2_matrix(df2[cols_categorical], use_p_value=use_p_value) @@ -635,8 +585,8 @@ def mean_diff_corr_matrix_categorical_vs_numerical_features( _check_same_number_columns(df1, df2) - cols_categorical = _get_categorical_features(df1) - cols_numerical = _get_numerical_features(df1) + cols_categorical = utils._get_categorical_features(df1) + cols_numerical = utils._get_numerical_features(df1) df_corr1 = _get_correlation_f_oneway_matrix( df1, cols_categorical, 
cols_numerical, use_p_value=use_p_value ) @@ -763,10 +713,10 @@ def sum_pairwise_distances( ########################### -def frechet_distance( +def frechet_distance_base( df1: pd.DataFrame, df2: pd.DataFrame, -) -> float: +) -> pd.Series: """Compute the Fréchet distance between two dataframes df1 and df2 Frechet_distance = || mu_1 - mu_2 ||_2^2 + Tr(Sigma_1 + Sigma_2 - 2(Sigma_1 . Sigma_2)^(1/2)) It is normalized, df1 and df2 are first scaled by a factor (std(df1) + std(df2)) / 2 @@ -783,8 +733,8 @@ def frechet_distance( Returns ------- - float - frechet distance + pd.Series + Frechet distance in a Series object """ if df1.shape != df2.shape: @@ -798,16 +748,23 @@ def frechet_distance( means1, cov1 = utils.nan_mean_cov(df1.values) means2, cov2 = utils.nan_mean_cov(df2.values) - return algebra.frechet_distance_exact(means1, cov1, means2, cov2) + distance = algebra.frechet_distance_exact(means1, cov1, means2, cov2) + return pd.Series(distance, index=["All"]) -def frechet_distance_pattern( +def frechet_distance( df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame, + method: str = "single", min_n_rows: int = 10, ) -> pd.Series: - """Frechet distance computed using a pattern decomposition + """ + Frechet distance computed using a pattern decomposition. Several variants are implemented: + - the `single` method relies on a single estimation of the means and covariance matrix. It is + relevant for MCAR data. + - the `pattern` method relies on the aggregation of the estimated distance between each + pattern. It is relevant for MAR data. Parameters ---------- @@ -817,6 +774,9 @@ def frechet_distance_pattern( Second empirical ditribution df_mask : pd.DataFrame Mask indicating on which values the distance has to computed on + method: str + Method used to compute the distance on multivariate datasets with missing values. + Possible values are `single` and `pattern`.
min_n_rows: int Minimum number of rows for a KL estimation @@ -826,6 +786,8 @@ def frechet_distance_pattern( Series of computed metrics """ + if method == "single": + return frechet_distance_base(df1, df2) return pattern_based_weighted_mean_metric( df1, df2, @@ -890,7 +852,7 @@ def kl_divergence_gaussian(df1: pd.DataFrame, df2: pd.DataFrame) -> float: return div_kl -def kl_divergence_pattern( +def kl_divergence( df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame, @@ -913,7 +875,8 @@ def kl_divergence_pattern( df_mask: pd.DataFrame Mask indicating on what values the divergence should be computed method: str - Method used + Method used to compute the divergence on multivariate datasets with missing values. + Possible values are `columnwise` and `gaussian`. min_n_rows: int Minimum number of rows for a KL estimation @@ -1073,12 +1036,13 @@ def get_metric(name: str) -> Callable: "wmape": weighted_mean_absolute_percentage_error, "accuracy": accuracy, "wasserstein_columnwise": dist_wasserstein, - "KL_columnwise": partial(kl_divergence_pattern, method="columnwise"), - "KL_gaussian": partial(kl_divergence_pattern, method="gaussian"), - "ks_test": kolmogorov_smirnov_test, + "KL_columnwise": partial(kl_divergence, method="columnwise"), + "KL_gaussian": partial(kl_divergence, method="gaussian"), + "KS_test": kolmogorov_smirnov_test, "correlation_diff": mean_difference_correlation_matrix_numerical_features, "energy": sum_energy_distances, - "frechet": frechet_distance_pattern, + "frechet_single": partial(frechet_distance, method="single"), + "frechet_pattern": partial(frechet_distance, method="pattern"), "dist_corr_pattern": distance_anticorr_pattern, } return dict_metrics[name] diff --git a/qolmat/utils/utils.py b/qolmat/utils/utils.py index 36771ecc..43433ea9 100644 --- a/qolmat/utils/utils.py +++ b/qolmat/utils/utils.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import warnings import numpy as np @@ 
-12,6 +12,56 @@ HyperValue = Union[int, float, str] +def _get_numerical_features(df1: pd.DataFrame) -> List[str]: + """Get numerical features from dataframe + + Parameters + ---------- + df1 : pd.DataFrame + + Returns + ------- + List[str] + List of numerical features + + Raises + ------ + Exception + No numerical feature is found + """ + cols_numerical = df1.select_dtypes(include=np.number).columns.tolist() + if len(cols_numerical) == 0: + raise Exception("No numerical feature is found.") + else: + return cols_numerical + + +def _get_categorical_features(df1: pd.DataFrame) -> List[str]: + """Get categorical features from dataframe + + Parameters + ---------- + df1 : pd.DataFrame + + Returns + ------- + List[str] + List of categorical features + + Raises + ------ + Exception + No categorical feature is found + """ + + cols_numerical = df1.select_dtypes(include=np.number).columns.tolist() + cols_categorical = [col for col in df1.columns.to_list() if col not in cols_numerical] + if len(cols_categorical) == 0: + raise Exception("No categorical feature is found.") + else: + return cols_categorical + + def _validate_input(X: NDArray) -> pd.DataFrame: """ Checks that the input X can be converted into a DataFrame, and returns the corresponding diff --git a/tests/benchmark/test_metrics.py b/tests/benchmark/test_metrics.py index b2b1f4b7..a714e81d 100644 --- a/tests/benchmark/test_metrics.py +++ b/tests/benchmark/test_metrics.py @@ -124,16 +124,16 @@ def test_wasserstein_distance(df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd. 
@pytest.mark.parametrize("df2", [df_imputed]) @pytest.mark.parametrize("df_mask", [df_mask]) def test_kl_divergence(df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame) -> None: - result = metrics.kl_divergence_pattern(df1, df1, df_mask, method="columnwise") + result = metrics.kl_divergence(df1, df1, df_mask, method="columnwise") expected = pd.Series([0.0, 0.0], index=["col1", "col2"]) pd.testing.assert_series_equal(result, expected, atol=1e-3) - result = metrics.kl_divergence_pattern(df1, df2, df_mask, method="columnwise") + result = metrics.kl_divergence(df1, df2, df_mask, method="columnwise") expected = pd.Series([18.945, 36.637], index=["col1", "col2"]) pd.testing.assert_series_equal(result, expected, atol=1e-3) df_nonan = df1.notna() - result = metrics.kl_divergence_pattern(df1, df2, df_nonan, method="gaussian", min_n_rows=2) + result = metrics.kl_divergence(df1, df2, df_nonan, method="gaussian", min_n_rows=2) expected = pd.Series([1.029], index=["All"]) pd.testing.assert_series_equal(result, expected, atol=1e-3) @@ -154,11 +154,11 @@ def test_kl_divergence_gaussian( @pytest.mark.parametrize("df1", [df_incomplete]) @pytest.mark.parametrize("df2", [df_imputed]) -def test_frechet_distance(df1: pd.DataFrame, df2: pd.DataFrame) -> None: - result = metrics.frechet_distance(df1, df1) +def test_frechet_distance_base(df1: pd.DataFrame, df2: pd.DataFrame) -> None: + result = metrics.frechet_distance_base(df1, df1) np.testing.assert_allclose(result, 0, atol=1e-3) - result = metrics.frechet_distance(df1, df2) + result = metrics.frechet_distance_base(df1, df2) np.testing.assert_allclose(result, 0.134, atol=1e-3) @@ -320,7 +320,7 @@ def test_exception_raise_different_shapes( with pytest.raises(Exception): metrics.mean_difference_correlation_matrix_numerical_features(df1, df2, df_mask) with pytest.raises(Exception): - metrics.frechet_distance(df1, df2) + metrics.frechet_distance_base(df1, df2) @pytest.mark.parametrize("df1", [df_incomplete_cat]) From 
397d26f8d0c7e3ddf31dc06c6084d4bd8bb879b5 Mon Sep 17 00:00:00 2001 From: Julien Roussel <3178729-JulienRoussel77@users.noreply.gitlab.com> Date: Mon, 15 Apr 2024 14:30:17 +0200 Subject: [PATCH 3/3] frechet distance refacto --- examples/benchmark.md | 17 +---------------- qolmat/benchmark/metrics.py | 14 ++++++++------ qolmat/imputations/preprocessing.py | 5 ++++- qolmat/utils/utils.py | 4 +--- tests/benchmark/test_metrics.py | 10 +--------- tests/imputations/test_preprocessing.py | 2 ++ 6 files changed, 17 insertions(+), 35 deletions(-) diff --git a/examples/benchmark.md b/examples/benchmark.md index e2fe8c1d..be5a73bf 100644 --- a/examples/benchmark.md +++ b/examples/benchmark.md @@ -16,9 +16,6 @@ jupyter: **This notebook aims to present the Qolmat repo through an example of a multivariate time series. In Qolmat, a few data imputation methods are implemented as well as a way to evaluate their performance.** -```python - -``` First, import some useful librairies @@ -36,26 +33,18 @@ from IPython.display import Image import pandas as pd from datetime import datetime import numpy as np -import scipy import hyperopt as ho -from hyperopt.pyll.base import Apply as hoApply np.random.seed(1234) -import pprint from matplotlib import pyplot as plt -import matplotlib.image as mpimg import matplotlib.ticker as plticker tab10 = plt.get_cmap("tab10") plt.rcParams.update({'font.size': 18}) -from typing import Optional from sklearn.linear_model import LinearRegression -from sklearn.ensemble import RandomForestRegressor, ExtraTreesRegressor, HistGradientBoostingRegressor - -import sys -from qolmat.benchmark import comparator, missing_patterns, hyperparameters +from qolmat.benchmark import comparator, missing_patterns from qolmat.imputations import imputers from qolmat.utils import data, utils, plot @@ -239,10 +228,6 @@ df_plot = data.add_datetime_features(df_plot, col_time="date") dfs_imputed = {name: imp.fit_transform(df_plot) for name, imp in dict_imputers.items()} ``` -```python 
tags=[] -dfs_imputed["VAR_max"].groupby("station").min() -``` - ```python tags=[] station = df_plot.index.get_level_values("station")[0] # station = "Huairou" diff --git a/qolmat/benchmark/metrics.py b/qolmat/benchmark/metrics.py index b68d3e6b..00ca0518 100644 --- a/qolmat/benchmark/metrics.py +++ b/qolmat/benchmark/metrics.py @@ -368,12 +368,12 @@ def total_variance_distance( pd.Series Total variance distance """ - cols_categorical = utils._get_categorical_features(df1) return columnwise_metric( - df1[cols_categorical], - df2[cols_categorical], - df_mask[cols_categorical], + df1, + df2, + df_mask, _total_variance_distance_1D, + type_cols="categorical", ) @@ -792,7 +792,7 @@ def frechet_distance( df1, df2, df_mask, - frechet_distance, + frechet_distance_base, min_n_rows=min_n_rows, type_cols="numerical", ) @@ -1003,10 +1003,12 @@ def pattern_based_weighted_mean_metric( cols = df1.select_dtypes(exclude=["number"]).columns else: raise ValueError(f"Value {type_cols} is not valid for parameter `type_cols`!") + if np.any(df_mask & df1.isna()): raise ValueError("The argument df1 has missing values on the mask!") if np.any(df_mask & df2.isna()): raise ValueError("The argument df2 has missing values on the mask!") + rows_mask = df_mask.any(axis=1) scores = [] weights = [] @@ -1041,7 +1043,7 @@ def get_metric(name: str) -> Callable: "KS_test": kolmogorov_smirnov_test, "correlation_diff": mean_difference_correlation_matrix_numerical_features, "energy": sum_energy_distances, - "frechet_single": partial(frechet_distance, method="single"), + "frechet": partial(frechet_distance, method="single"), "frechet_pattern": partial(frechet_distance, method="pattern"), "dist_corr_pattern": distance_anticorr_pattern, } diff --git a/qolmat/imputations/preprocessing.py b/qolmat/imputations/preprocessing.py index 29d48e58..15ff048e 100644 --- a/qolmat/imputations/preprocessing.py +++ b/qolmat/imputations/preprocessing.py @@ -314,10 +314,13 @@ def make_pipeline_mixte_preprocessing( ohe = 
OneHotEncoder(handle_unknown="ignore", use_cat_names=True) transformers += [("cat", ohe, selector(dtype_exclude=np.number))] - col_transformer = ColumnTransformer(transformers=transformers).set_output(transform="pandas") + col_transformer = ColumnTransformer(transformers=transformers, remainder="passthrough") + col_transformer = col_transformer.set_output(transform="pandas") preprocessor = Pipeline(steps=[("col_transformer", col_transformer)]) + if avoid_new: preprocessor.steps.append(("bins", BinTransformer())) + print(preprocessor) return preprocessor diff --git a/qolmat/utils/utils.py b/qolmat/utils/utils.py index 43433ea9..ce8f7865 100644 --- a/qolmat/utils/utils.py +++ b/qolmat/utils/utils.py @@ -288,7 +288,7 @@ def get_shape_original(M: NDArray, shape: tuple) -> NDArray: def create_lag_matrices(X: NDArray, p: int) -> Tuple[NDArray, NDArray]: - n_rows, n_cols = X.shape + n_rows, _ = X.shape n_rows_new = n_rows - p list_X_lag = [np.ones((n_rows_new, 1))] for lag in range(p): @@ -304,7 +304,5 @@ def nan_mean_cov(X: NDArray) -> Tuple[NDArray, NDArray]: _, n_variables = X.shape means = np.nanmean(X, axis=0) cov = np.ma.cov(np.ma.masked_invalid(X), rowvar=False).data - print(cov.shape) - print(X.shape) cov = cov.reshape(n_variables, n_variables) return means, cov diff --git a/tests/benchmark/test_metrics.py b/tests/benchmark/test_metrics.py index a714e81d..0c768054 100644 --- a/tests/benchmark/test_metrics.py +++ b/tests/benchmark/test_metrics.py @@ -383,20 +383,12 @@ def test_pattern_based_weighted_mean_metric( def test_pattern_mae_comparison(mocker) -> None: - # def mock_metric(values1: pd.Series, values2: pd.Series) -> float: - # call_count += 1 - # return 0 - mock_metric = mocker.patch("qolmat.benchmark.metrics.accuracy_1D", return_value=0) - # def fun_mean_mae(df_gauss1, df_gauss2, df_mask_gauss) -> float: - # return metrics.mean_squared_error(df_gauss1, df_gauss2, df_mask_gauss).mean() - print(df_mask) df_nonan = df_incomplete.notna() - result = 
metrics.pattern_based_weighted_mean_metric( + metrics.pattern_based_weighted_mean_metric( df_incomplete, df_imputed, df_nonan, metric=mock_metric, min_n_rows=1 ) - print(result) assert mock_metric.call_count == 2 diff --git a/tests/imputations/test_preprocessing.py b/tests/imputations/test_preprocessing.py index 30b55bd3..5226c332 100644 --- a/tests/imputations/test_preprocessing.py +++ b/tests/imputations/test_preprocessing.py @@ -198,6 +198,8 @@ def test_preprocessing_pipeline(preprocessing_pipeline): # Test with numerical features X_num = pd.DataFrame([[1, 2], [3, 4], [5, 6]]) X_transformed = preprocessing_pipeline.fit_transform(X_num) + print(X_num.shape) + print(X_transformed.shape) assert isinstance(X_transformed, pd.DataFrame) assert X_transformed.shape[1] == X_num.shape[1]