From 0c793bb1f740e1747a355b951b17b2028d43060c Mon Sep 17 00:00:00 2001 From: Anh Khoa Ngo Ho Date: Wed, 6 Dec 2023 16:14:10 +0100 Subject: [PATCH 1/2] fix: rpca --- qolmat/imputations/imputers.py | 23 ++----- qolmat/imputations/rpca/rpca.py | 80 +++++++++++++++++++++-- qolmat/imputations/rpca/rpca_noisy.py | 51 +++++++++++---- tests/imputations/rpca/test_rpca.py | 5 +- tests/imputations/rpca/test_rpca_noisy.py | 8 +-- tests/imputations/rpca/test_rpca_pcp.py | 10 +-- tests/imputations/test_imputers.py | 3 +- 7 files changed, 134 insertions(+), 46 deletions(-) diff --git a/qolmat/imputations/imputers.py b/qolmat/imputations/imputers.py index fb393221..558526a4 100644 --- a/qolmat/imputations/imputers.py +++ b/qolmat/imputations/imputers.py @@ -1783,13 +1783,9 @@ def _fit_element( model = self.get_model(**hyperparams) X = df.astype(float).values - D = utils.prepare_data(X, model.period) - Omega = ~np.isnan(D) - D = utils.linear_interpolation(D) + model = model.fit_basis(X) - Q = model.fit_basis(D, Omega) - - return Q + return model def _transform_element( self, df: pd.DataFrame, col: str = "__all__", ngroup: int = 0 @@ -1821,20 +1817,11 @@ def _transform_element( if self.method not in ["PCP", "noisy"]: raise ValueError("Argument method must be `PCP` or `noisy`!") - hyperparams = self.get_hyperparams(col=col) - model = self.get_model(random_state=self._rng, **hyperparams) + model = self._dict_fitting[col][ngroup] X = df.astype(float).values - D = utils.prepare_data(X, model.period) - Omega = ~np.isnan(D) - D = utils.linear_interpolation(D) - - Q = self._dict_fitting[col][ngroup] - M, A = model.decompose_on_basis(D, Omega, Q) - - M_final = utils.get_shape_original(M, X.shape) - A_final = utils.get_shape_original(A, X.shape) - X_imputed = M_final + A_final + M, A = model.decompose_rpca_signal(X) + X_imputed = M + A df_imputed = pd.DataFrame(X_imputed, index=df.index, columns=df.columns) df_imputed = df.where(~df.isna(), df_imputed) diff --git a/qolmat/imputations/rpca/rpca.py b/qolmat/imputations/rpca/rpca.py index 61fe52bb..f7a9f4ee 100644 --- a/qolmat/imputations/rpca/rpca.py +++ b/qolmat/imputations/rpca/rpca.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Union +from typing import Union, Tuple from typing_extensions import Self import numpy as np @@ -43,18 +43,86 @@ def __init__( self.random_state = random_state self.verbose = verbose - def fit_basis(self, D: NDArray, Omega: NDArray) -> NDArray: + def fit_basis(self, X: NDArray) -> Self: + """Fit RPCA model on data + + Parameters + ---------- + X : NDArray + Observations + + Returns + ------- + Self + Model RPCA + """ + D = utils.prepare_data(X, self.period) + Omega = ~np.isnan(D) + D = utils.linear_interpolation(D) + n_rows, n_cols = D.shape if n_rows == 1 or n_cols == 1: self.V = np.array([[1]]) return self - M, A, L, Q = self.decompose_rpca(D, Omega) - return Q + _, _, _, Q = self.decompose_rpca(D, Omega) + + self.Q = Q + + return self + + def decompose_on_basis( + self, D: NDArray, Omega: NDArray, Q: NDArray + ) -> Tuple[NDArray, NDArray]: + """Decompose data + + Parameters + ---------- + D : NDArray + Observations + Omega : NDArray + Boolean matrix indicating the observed values + Q : NDArray + Learned basis unitary array of shape (rank, n). - def decompose_on_basis(self, D: NDArray, Omega: NDArray, Q: NDArray) -> NDArray: + Returns + ------- + Tuple[NDArray, NDArray] + M : np.ndarray + Low-rank signal matrix of shape (m, n). + A : np.ndarray + Anomalies matrix of shape (m, n). + """ n_rows, n_cols = D.shape if n_rows == 1 or n_cols == 1: return D, np.full_like(D, 0) - M, A, L, Q = self.decompose_rpca(D, Omega) + M, A, _, _ = self.decompose_rpca(D, Omega) return M, A + + def decompose_rpca_signal(self, X: NDArray) -> Tuple[NDArray, NDArray]: + """ + Compute the noisy RPCA with L1 or L2 time penalisation + + Parameters + ---------- + X : NDArray + Observations + + Returns + ------- + M: NDArray + Low-rank signal + A: NDArray + Anomalies + """ + + D = utils.prepare_data(X, self.period) + Omega = ~np.isnan(D) + D = utils.linear_interpolation(D) + + M, A = self.decompose_on_basis(D, Omega, self.Q) + + M_final = utils.get_shape_original(M, X.shape) + A_final = utils.get_shape_original(A, X.shape) + + return M_final, A_final diff --git a/qolmat/imputations/rpca/rpca_noisy.py b/qolmat/imputations/rpca/rpca_noisy.py index 7f29fb84..02a50f11 100644 --- a/qolmat/imputations/rpca/rpca_noisy.py +++ b/qolmat/imputations/rpca/rpca_noisy.py @@ -113,15 +113,17 @@ def get_params_scale(self, D: NDArray) -> Dict[str, float]: } def decompose_on_basis( - self, D: NDArray, Omega: NDArray, Q: NDArray + self, + D: NDArray, + Omega: NDArray, + Q: NDArray, ) -> Tuple[NDArray, NDArray]: - params_scale = self.get_params_scale(D) - lam = params_scale["lam"] if self.lam is None else self.lam - rank = params_scale["rank"] if self.rank is None else self.rank - rank = int(rank) - tau = params_scale["tau"] if self.tau is None else self.tau + lam = self.params_scale["lam"] + # rank = int(self.params_scale["rank"]) + tau = self.params_scale["tau"] + print(self.lam, lam) n_rows, n_cols = D.shape if n_rows == 1 or n_cols == 1: return D, np.full_like(D, 0) @@ -130,6 +132,26 @@ def decompose_on_basis( Ir = np.eye(n_rank) L = np.zeros((n_rows, n_rank)) A = np.zeros((n_rows, n_cols)) + + for _ in range(self.max_iterations): + A_prev = A.copy() + L_prev = L.copy() + L = scp.linalg.solve( + a=2 * tau * Ir + (Q @ Q.T), + b=Q @ (D - A).T, + ).T + A_Omega = rpca_utils.soft_thresholding(D - L @ Q, lam) + A_Omega_C = D - L @ Q + A = np.where(Omega, A_Omega, A_Omega_C) + + Ac = np.linalg.norm(A - A_prev, np.inf) + Lc = np.linalg.norm(L - L_prev, np.inf) + + tolerance = max([Ac, Lc]) # type: ignore # noqa + + if tolerance < self.tol: + break + for i in range(n_rows): d = D[i, :] omega = Omega[i, :] @@ -146,6 +168,7 @@ def decompose_on_basis( ).T L[i, :] = L_row A[i, :] = a + M = L @ Q return M, A @@ -171,12 +194,18 @@ def decompose_rpca( Anomalies """ - params_scale = self.get_params_scale(D) + self.params_scale = self.get_params_scale(D) + + if self.lam is not None: + self.params_scale["lam"] = self.lam + if self.rank is not None: + self.params_scale["rank"] = self.rank + if self.tau is not None: + self.params_scale["tau"] = self.tau - lam = params_scale["lam"] if self.lam is None else self.lam - rank = params_scale["rank"] if self.rank is None else self.rank - rank = int(rank) - tau = params_scale["tau"] if self.tau is None else self.tau + lam = self.params_scale["lam"] + rank = int(self.params_scale["rank"]) + tau = self.params_scale["tau"] mu = 1e-2 if self.mu is None else self.mu n_rows, _ = D.shape diff --git a/tests/imputations/rpca/test_rpca.py b/tests/imputations/rpca/test_rpca.py index bde2b7c7..bde0722b 100644 --- a/tests/imputations/rpca/test_rpca.py +++ b/tests/imputations/rpca/test_rpca.py @@ -24,8 +24,11 @@ class RPCAMock(RPCA): def __init__(self): super().__init__() + self.Q = None - def decompose_rpca(self, D: NDArray, Omega: NDArray) -> Tuple[NDArray, NDArray]: + def decompose_on_basis( + self, D: NDArray, Omega: NDArray, Q: NDArray + ) -> Tuple[NDArray, NDArray]: self.call_count = 1 return D, D diff --git a/tests/imputations/rpca/test_rpca_noisy.py b/tests/imputations/rpca/test_rpca_noisy.py index f39d8ea8..8e19c76c 100644 --- a/tests/imputations/rpca/test_rpca_noisy.py +++ b/tests/imputations/rpca/test_rpca_noisy.py @@ -93,7 +93,7 @@ def test_rpca_noisy_get_params_scale(X: NDArray): @pytest.mark.parametrize("norm", ["L2"]) def test_rpca_decompose_rpca_signal_shape(norm: str): """Test RPCA noisy results if tau and lambda equal zero.""" - rpca = RPCANoisy(rank=2, norm=norm) + rpca = RPCANoisy(rank=2, norm=norm).fit_basis(X_test) X_result, A_result = rpca.decompose_rpca_signal(X_test) assert X_result.shape == X_test.shape assert A_result.shape == X_test.shape @@ -102,7 +102,7 @@ def test_rpca_decompose_rpca_signal_shape(norm: str): @pytest.mark.parametrize("X, X_interpolated", [(X_incomplete, X_interpolated)]) def test_rpca_noisy_zero_tau_zero_lambda(X: NDArray, X_interpolated: NDArray): """Test RPCA noisy results if tau and lambda equal zero.""" - rpca = RPCANoisy(tau=0, lam=0, norm="L2") + rpca = RPCANoisy(tau=0, lam=0, norm="L2").fit_basis(X) X_result, A_result = rpca.decompose_rpca_signal(X) np.testing.assert_allclose(X_result, X_interpolated, atol=1e-4) np.testing.assert_allclose(A_result, np.full_like(X, 0), atol=1e-4) @@ -114,7 +114,7 @@ def test_rpca_noisy_zero_tau_zero_lambda(X: NDArray, X_interpolated: NDArray): ) def test_rpca_noisy_zero_tau(X: NDArray, lam: float, X_interpolated: NDArray): """Test RPCA noisy results if tau equals zero.""" - rpca = RPCANoisy(tau=0, lam=lam, norm="L2") + rpca = RPCANoisy(tau=0, lam=lam, norm="L2").fit_basis(X) X_result, A_result = rpca.decompose_rpca_signal(X) np.testing.assert_allclose(X_result, X_interpolated, atol=1e-4) np.testing.assert_allclose(A_result, np.full_like(X, 0), atol=1e-4) @@ -126,7 +126,7 @@ def test_rpca_noisy_zero_tau(X: NDArray, lam: float, X_interpolated: NDArray): ) def test_rpca_noisy_zero_lambda(X: NDArray, tau: float, X_interpolated: NDArray): """Test RPCA noisy results if lambda equals zero.""" - rpca = RPCANoisy(tau=tau, lam=0, norm="L2") + rpca = RPCANoisy(tau=tau, lam=0, norm="L2").fit_basis(X) X_result, A_result = rpca.decompose_rpca_signal(X) np.testing.assert_allclose(X_result, np.full_like(X, 0), atol=1e-4) np.testing.assert_allclose(A_result, X_interpolated, atol=1e-4) diff --git a/tests/imputations/rpca/test_rpca_pcp.py b/tests/imputations/rpca/test_rpca_pcp.py index 0cdc2c94..91df1d9b 100644 --- a/tests/imputations/rpca/test_rpca_pcp.py +++ b/tests/imputations/rpca/test_rpca_pcp.py @@ -75,7 +75,7 @@ def test_check_cost_function_minimized_no_warning( @pytest.mark.parametrize("X", [X_complete]) def test_rpca_rpca_pcp_get_params_scale(X: NDArray): """Test the parameters are well scaled.""" - rpca_pcp = RPCAPCP(max_iterations=max_iterations, mu=0.5, lam=0.1) + rpca_pcp = RPCAPCP(max_iterations=max_iterations, mu=0.5, lam=0.1).fit_basis(X) result_dict = rpca_pcp.get_params_scale(X) result = list(result_dict.values()) params_expected = [1 / 7, np.sqrt(2) / 2] @@ -88,7 +88,7 @@ def test_rpca_rpca_pcp_zero_lambda_small_mu(X: NDArray, mu: float): The problem is ill-conditioned and the result depends on the parameter mu; case when mu is small. """ - rpca_pcp = RPCAPCP(lam=0, mu=mu) + rpca_pcp = RPCAPCP(lam=0, mu=mu).fit_basis(X) X_result, A_result = rpca_pcp.decompose_rpca_signal(X) np.testing.assert_allclose(X_result, np.full_like(X, 0), atol=1e-4) np.testing.assert_allclose(A_result, X, atol=1e-4) @@ -100,7 +100,7 @@ def test_rpca_rpca_pcp_zero_lambda_large_mu(X: NDArray, mu: float): The problem is ill-conditioned and the result depends on the parameter mu; case when mu is large. """ - rpca_pcp = RPCAPCP(lam=0, mu=mu) + rpca_pcp = RPCAPCP(lam=0, mu=mu).fit_basis(X) X_result, A_result = rpca_pcp.decompose_rpca_signal(X) np.testing.assert_allclose(X_result, X, atol=1e-4) np.testing.assert_allclose(A_result, np.full_like(X, 0), atol=1e-4) @@ -109,7 +109,7 @@ def test_rpca_rpca_pcp_zero_lambda_large_mu(X: NDArray, mu: float): @pytest.mark.parametrize("X, mu", [(X_complete, large_mu)]) def test_rpca_rpca_pcp_large_lambda_small_mu(X: NDArray, mu: float): """Test RPCA PCP results with large lambda and small mu.""" - rpca_pcp = RPCAPCP(lam=1e3, mu=mu) + rpca_pcp = RPCAPCP(lam=1e3, mu=mu).fit_basis(X) X_result, A_result = rpca_pcp.decompose_rpca_signal(X) np.testing.assert_allclose(X_result, X, atol=1e-4) np.testing.assert_allclose(A_result, np.full_like(X, 0), atol=1e-4) @@ -121,7 +121,7 @@ def test_rpca_temporal_signal(synthetic_temporal_data): signal = synthetic_temporal_data period = 100 lam = 0.1 - rpca = RPCAPCP(period=period, lam=lam, mu=0.01) + rpca = RPCAPCP(period=period, lam=lam, mu=0.01).fit_basis(signal) X_result, A_result = rpca.decompose_rpca_signal(signal) X_input_rpca = utils.linear_interpolation(signal.reshape(period, -1)) assert np.linalg.norm(X_input_rpca, "nuc") >= np.linalg.norm( diff --git a/tests/imputations/test_imputers.py b/tests/imputations/test_imputers.py index 694288c2..6b7b7bd4 100644 --- a/tests/imputations/test_imputers.py +++ b/tests/imputations/test_imputers.py @@ -263,7 +263,8 @@ def test_ImputerRegressor_fit_transform(df: pd.DataFrame) -> None: @pytest.mark.parametrize("df", [df_timeseries]) def test_ImputerRPCA_fit_transform(df: pd.DataFrame) -> None: imputer = imputers.ImputerRPCA(columnwise=False, max_iterations=100, tau=1, lam=0.3) - result = imputer.fit_transform(df) + imputer = imputer.fit(df) + result = imputer.transform(df) expected = pd.DataFrame( { "col1": [i for i in range(20)], From c779ed15b67875b4d313706e1883aa390706b3f2 Mon Sep 17 00:00:00 2001 From: Anh Khoa Ngo Ho Date: Thu, 7 Dec 2023 14:15:26 +0100 Subject: [PATCH 2/2] fix: full_matrices=False and unit tests --- qolmat/imputations/imputers.py | 3 +- qolmat/imputations/rpca/rpca.py | 35 ++++++++++++++++++++++- qolmat/imputations/rpca/rpca_noisy.py | 21 +------------- qolmat/imputations/rpca/rpca_utils.py | 2 +- tests/imputations/rpca/test_rpca.py | 11 +++++++ tests/imputations/rpca/test_rpca_noisy.py | 8 +++--- tests/imputations/rpca/test_rpca_pcp.py | 8 +++--- tests/imputations/test_imputers.py | 10 +++++++ 8 files changed, 66 insertions(+), 32 deletions(-) diff --git a/qolmat/imputations/imputers.py b/qolmat/imputations/imputers.py index 558526a4..9f2ae1ec 100644 --- a/qolmat/imputations/imputers.py +++ b/qolmat/imputations/imputers.py @@ -1820,8 +1820,7 @@ def _transform_element( model = self._dict_fitting[col][ngroup] X = df.astype(float).values - M, A = model.decompose_rpca_signal(X) - X_imputed = M + A + X_imputed = model.transform_with_basis(X) df_imputed = pd.DataFrame(X_imputed, index=df.index, columns=df.columns) df_imputed = df.where(~df.isna(), df_imputed) diff --git a/qolmat/imputations/rpca/rpca.py b/qolmat/imputations/rpca/rpca.py index f7a9f4ee..026ecaa9 100644 --- a/qolmat/imputations/rpca/rpca.py +++ b/qolmat/imputations/rpca/rpca.py @@ -99,6 +99,36 @@ def decompose_on_basis( M, A, _, _ = self.decompose_rpca(D, Omega) return M, A + def transform_with_basis(self, X: NDArray) -> NDArray: + """ + Compute the noisy RPCA with L1 or L2 time penalisation + + Parameters + ---------- + X : NDArray + Observations + + Returns + ------- + X_final: NDArray + M + A + """ + + D = utils.prepare_data(X, self.period) + Omega = ~np.isnan(D) + D = utils.linear_interpolation(D) + n_rows, n_cols = D.shape + if n_rows == 1 or n_cols == 1: + return D + + M, A = self.decompose_on_basis(D, Omega, self.Q) + + M_final = utils.get_shape_original(M, X.shape) + A_final = utils.get_shape_original(A, X.shape) + + X_final = M_final + A_final + return X_final + def decompose_rpca_signal(self, X: NDArray) -> Tuple[NDArray, NDArray]: """ Compute the noisy RPCA with L1 or L2 time penalisation @@ -119,8 +149,11 @@ def decompose_rpca_signal(self, X: NDArray) -> Tuple[NDArray, NDArray]: D = utils.prepare_data(X, self.period) Omega = ~np.isnan(D) D = utils.linear_interpolation(D) + n_rows, n_cols = D.shape + if n_rows == 1 or n_cols == 1: + return D, np.full_like(D, 0) - M, A = self.decompose_on_basis(D, Omega, self.Q) + M, A, _, _ = self.decompose_rpca(D, Omega) M_final = utils.get_shape_original(M, X.shape) A_final = utils.get_shape_original(A, X.shape) diff --git a/qolmat/imputations/rpca/rpca_noisy.py b/qolmat/imputations/rpca/rpca_noisy.py index 02a50f11..96570510 100644 --- a/qolmat/imputations/rpca/rpca_noisy.py +++ b/qolmat/imputations/rpca/rpca_noisy.py @@ -120,10 +120,8 @@ def decompose_on_basis( ) -> Tuple[NDArray, NDArray]: lam = self.params_scale["lam"] - # rank = int(self.params_scale["rank"]) tau = self.params_scale["tau"] - print(self.lam, lam) n_rows, n_cols = D.shape if n_rows == 1 or n_cols == 1: return D, np.full_like(D, 0) @@ -152,23 +150,6 @@ def decompose_on_basis( if tolerance < self.tol: break - for i in range(n_rows): - d = D[i, :] - omega = Omega[i, :] - L_row = np.zeros((1, n_rank)) - a = np.full_like(d, 0) - for _ in range(self.max_iterations): - a_omega = rpca_utils.soft_thresholding(d - L_row @ Q, lam) - a_omega_C = d - L_row @ Q - a = np.where(omega, a_omega, a_omega_C) - - L_row = scp.linalg.solve( - a=2 * tau * Ir + (Q @ Q.T), - b=Q @ (d - a).T, - ).T - L[i, :] = L_row - A[i, :] = a - M = L @ Q return M, A @@ -357,7 +338,7 @@ def decompose_rpca_algorithm( Y = np.zeros((n_rows, n_cols)) X = D.copy() A = np.zeros((n_rows, n_cols)) - U, S, Vt = np.linalg.svd(X) + U, S, Vt = np.linalg.svd(X, full_matrices=False) U = U[:, :rank] S = S[:rank] diff --git a/qolmat/imputations/rpca/rpca_utils.py b/qolmat/imputations/rpca/rpca_utils.py index c84b8aa0..ea7bb603 100644 --- a/qolmat/imputations/rpca/rpca_utils.py +++ b/qolmat/imputations/rpca/rpca_utils.py @@ -31,7 +31,7 @@ def approx_rank( """ if threshold == 1: return min(M.shape) - _, values_singular, _ = np.linalg.svd(M, full_matrices=True) + _, values_singular, _ = np.linalg.svd(M, full_matrices=False) cum_sum = np.cumsum(values_singular) / np.sum(values_singular) rank = np.argwhere(cum_sum > threshold)[0][0] + 1 diff --git a/tests/imputations/rpca/test_rpca.py b/tests/imputations/rpca/test_rpca.py index bde0722b..cded4c70 100644 --- a/tests/imputations/rpca/test_rpca.py +++ b/tests/imputations/rpca/test_rpca.py @@ -26,6 +26,10 @@ def __init__(self): super().__init__() self.Q = None + def decompose_rpca(self, D: NDArray, Omega: NDArray) -> Tuple[NDArray, NDArray, None, None]: + self.call_count = 1 + return D, D, None, None + def decompose_on_basis( self, D: NDArray, Omega: NDArray, Q: NDArray ) -> Tuple[NDArray, NDArray]: @@ -42,3 +46,10 @@ def test_rpca_decompose_rpca_signal() -> None: assert M.shape == X_incomplete.shape assert A.shape == X_incomplete.shape assert rpca.call_count == 1 + + +def test_transform_with_basis() -> None: + rpca = RPCAMock() + X_imputed = rpca.transform_with_basis(X_incomplete) + assert X_imputed.shape == X_incomplete.shape + assert rpca.call_count == 1 diff --git a/tests/imputations/rpca/test_rpca_noisy.py b/tests/imputations/rpca/test_rpca_noisy.py index 8e19c76c..f39d8ea8 100644 --- a/tests/imputations/rpca/test_rpca_noisy.py +++ b/tests/imputations/rpca/test_rpca_noisy.py @@ -93,7 +93,7 @@ def test_rpca_noisy_get_params_scale(X: NDArray): @pytest.mark.parametrize("norm", ["L2"]) def test_rpca_decompose_rpca_signal_shape(norm: str): """Test RPCA noisy results if tau and lambda equal zero.""" - rpca = RPCANoisy(rank=2, norm=norm).fit_basis(X_test) + rpca = RPCANoisy(rank=2, norm=norm) X_result, A_result = rpca.decompose_rpca_signal(X_test) assert X_result.shape == X_test.shape assert A_result.shape == X_test.shape @@ -102,7 +102,7 @@ def test_rpca_decompose_rpca_signal_shape(norm: str): @pytest.mark.parametrize("X, X_interpolated", [(X_incomplete, X_interpolated)]) def test_rpca_noisy_zero_tau_zero_lambda(X: NDArray, X_interpolated: NDArray): """Test RPCA noisy results if tau and lambda equal zero.""" - rpca = RPCANoisy(tau=0, lam=0, norm="L2").fit_basis(X) + rpca = RPCANoisy(tau=0, lam=0, norm="L2") X_result, A_result = rpca.decompose_rpca_signal(X) np.testing.assert_allclose(X_result, X_interpolated, atol=1e-4) np.testing.assert_allclose(A_result, np.full_like(X, 0), atol=1e-4) @@ -114,7 +114,7 @@ def test_rpca_noisy_zero_tau_zero_lambda(X: NDArray, X_interpolated: NDArray): ) def test_rpca_noisy_zero_tau(X: NDArray, lam: float, X_interpolated: NDArray): """Test RPCA noisy results if tau equals zero.""" - rpca = RPCANoisy(tau=0, lam=lam, norm="L2").fit_basis(X) + rpca = RPCANoisy(tau=0, lam=lam, norm="L2") X_result, A_result = rpca.decompose_rpca_signal(X) np.testing.assert_allclose(X_result, X_interpolated, atol=1e-4) np.testing.assert_allclose(A_result, np.full_like(X, 0), atol=1e-4) @@ -126,7 +126,7 @@ def test_rpca_noisy_zero_tau(X: NDArray, lam: float, X_interpolated: NDArray): ) def test_rpca_noisy_zero_lambda(X: NDArray, tau: float, X_interpolated: NDArray): """Test RPCA noisy results if lambda equals zero.""" - rpca = RPCANoisy(tau=tau, lam=0, norm="L2").fit_basis(X) + rpca = RPCANoisy(tau=tau, lam=0, norm="L2") X_result, A_result = rpca.decompose_rpca_signal(X) np.testing.assert_allclose(X_result, np.full_like(X, 0), atol=1e-4) np.testing.assert_allclose(A_result, X_interpolated, atol=1e-4) diff --git a/tests/imputations/rpca/test_rpca_pcp.py b/tests/imputations/rpca/test_rpca_pcp.py index 91df1d9b..2b5267a2 100644 --- a/tests/imputations/rpca/test_rpca_pcp.py +++ b/tests/imputations/rpca/test_rpca_pcp.py @@ -75,7 +75,7 @@ def test_check_cost_function_minimized_no_warning( @pytest.mark.parametrize("X", [X_complete]) def test_rpca_rpca_pcp_get_params_scale(X: NDArray): """Test the parameters are well scaled.""" - rpca_pcp = RPCAPCP(max_iterations=max_iterations, mu=0.5, lam=0.1).fit_basis(X) + rpca_pcp = RPCAPCP(max_iterations=max_iterations, mu=0.5, lam=0.1) result_dict = rpca_pcp.get_params_scale(X) result = list(result_dict.values()) params_expected = [1 / 7, np.sqrt(2) / 2] @@ -88,7 +88,7 @@ def test_rpca_rpca_pcp_zero_lambda_small_mu(X: NDArray, mu: float): The problem is ill-conditioned and the result depends on the parameter mu; case when mu is small. """ - rpca_pcp = RPCAPCP(lam=0, mu=mu).fit_basis(X) + rpca_pcp = RPCAPCP(lam=0, mu=mu) X_result, A_result = rpca_pcp.decompose_rpca_signal(X) np.testing.assert_allclose(X_result, np.full_like(X, 0), atol=1e-4) np.testing.assert_allclose(A_result, X, atol=1e-4) @@ -100,7 +100,7 @@ def test_rpca_rpca_pcp_zero_lambda_large_mu(X: NDArray, mu: float): The problem is ill-conditioned and the result depends on the parameter mu; case when mu is large. """ - rpca_pcp = RPCAPCP(lam=0, mu=mu).fit_basis(X) + rpca_pcp = RPCAPCP(lam=0, mu=mu) X_result, A_result = rpca_pcp.decompose_rpca_signal(X) np.testing.assert_allclose(X_result, X, atol=1e-4) np.testing.assert_allclose(A_result, np.full_like(X, 0), atol=1e-4) @@ -109,7 +109,7 @@ def test_rpca_rpca_pcp_zero_lambda_large_mu(X: NDArray, mu: float): @pytest.mark.parametrize("X, mu", [(X_complete, large_mu)]) def test_rpca_rpca_pcp_large_lambda_small_mu(X: NDArray, mu: float): """Test RPCA PCP results with large lambda and small mu.""" - rpca_pcp = RPCAPCP(lam=1e3, mu=mu).fit_basis(X) + rpca_pcp = RPCAPCP(lam=1e3, mu=mu) X_result, A_result = rpca_pcp.decompose_rpca_signal(X) np.testing.assert_allclose(X_result, X, atol=1e-4) np.testing.assert_allclose(A_result, np.full_like(X, 0), atol=1e-4) diff --git a/tests/imputations/test_imputers.py b/tests/imputations/test_imputers.py index 6b7b7bd4..d6b405bf 100644 --- a/tests/imputations/test_imputers.py +++ b/tests/imputations/test_imputers.py @@ -274,6 +274,16 @@ def test_ImputerRPCA_fit_transform(df: pd.DataFrame) -> None: result = np.around(result) np.testing.assert_allclose(result, expected, atol=1e-2) + result = imputer.transform(df.iloc[:10]) + expected = pd.DataFrame( + { + "col1": [i for i in range(10)], + "col2": [0, 1, 2, 2, 2] + [i for i in range(5, 10)], + } + ) + result = np.around(result) + np.testing.assert_allclose(result, expected, atol=1e-2) + @pytest.mark.parametrize("df", [df_incomplete]) def test_ImputerSoftImpute_fit_transform(df: pd.DataFrame) -> None: