diff --git a/cytominer_eval/evaluate.py b/cytominer_eval/evaluate.py index 6bd7b80..a8c7137 100644 --- a/cytominer_eval/evaluate.py +++ b/cytominer_eval/evaluate.py @@ -28,6 +28,7 @@ def evaluate( replicate_reproducibility_return_median_cor: bool = False, precision_recall_k: int = 10, grit_control_perts: List[str] = ["None"], + grit_replicate_summary_method: str = "mean", mp_value_params: dict = {}, ): r"""Evaluate profile quality and strength. @@ -86,10 +87,14 @@ def evaluate( Only used when `operation='grit'`. Specific profile identifiers used as a reference when calculating grit. The list entries must be found in the `replicate_groups[replicate_id]` column. + grit_replicate_summary_method : {"mean", "median"}, optional + Only used when `operation='grit'`. Defines how the replicate z scores are + summarized. see + :py:func:`cytominer_eval.operations.util.calculate_grit` mp_value_params : {{}, ...}, optional Only used when `operation='mp_value'`. A key, item pair of optional parameters for calculating mp value. See also - cytominer_eval.operations.util.default_mp_value_parameters + :py:func:`cytominer_eval.operations.util.default_mp_value_parameters` """ # Check replicate groups input check_replicate_groups(eval_metric=operation, replicate_groups=replicate_groups) @@ -124,6 +129,7 @@ def evaluate( control_perts=grit_control_perts, replicate_id=replicate_groups["replicate_id"], group_id=replicate_groups["group_id"], + replicate_summary_method=grit_replicate_summary_method, ) elif operation == "mp_value": metric_result = mp_value( diff --git a/cytominer_eval/operations/grit.py b/cytominer_eval/operations/grit.py index 1c7edce..3b945a3 100644 --- a/cytominer_eval/operations/grit.py +++ b/cytominer_eval/operations/grit.py @@ -7,7 +7,7 @@ import pandas as pd from typing import List -from .util import assign_replicates, calculate_grit +from .util import assign_replicates, calculate_grit, check_grit_replicate_summary_method from cytominer_eval.transform.util import ( set_pair_ids, set_grit_column_info, @@ -20,6 +20,7 @@ def grit( control_perts: List[str], replicate_id: str, group_id: str, + replicate_summary_method: str = "mean", ) -> pd.DataFrame: r"""Calculate grit @@ -30,16 +31,20 @@ def grit( control_perts : list a list of control perturbations to calculate a null distribution replicate_id : str - the metadata identifier marking which column tracks replicate perts + the metadata identifier marking which column tracks unique identifiers group_id : str - the metadata identifier marking which column tracks a higher order groups for - all perturbations + the metadata identifier marking which column defines how replicates are grouped + replicate_summary_method : {'mean', 'median'}, optional + how replicate z-scores to control perts are summarized. Defaults to "mean". Returns ------- pandas.DataFrame A dataframe of grit measurements per perturbation """ + # Check if we support the provided summary method + check_grit_replicate_summary_method(replicate_summary_method) + # Determine pairwise replicates similarity_melted_df = assign_replicates( similarity_melted_df=similarity_melted_df, @@ -61,7 +66,14 @@ def grit( # Calculate grit for each perturbation grit_df = ( similarity_melted_df.groupby(replicate_col_name) - .apply(lambda x: calculate_grit(x, control_perts, column_id_info)) + .apply( + lambda x: calculate_grit( + replicate_group_df=x, + control_perts=control_perts, + column_id_info=column_id_info, + replicate_summary_method=replicate_summary_method, + ) + ) .reset_index(drop=True) ) diff --git a/cytominer_eval/operations/util.py b/cytominer_eval/operations/util.py index 79cc129..b253fcd 100644 --- a/cytominer_eval/operations/util.py +++ b/cytominer_eval/operations/util.py @@ -7,7 +7,10 @@ from sklearn.covariance import EmpiricalCovariance from cytominer_eval.transform import metric_melt -from cytominer_eval.transform.util import set_pair_ids +from cytominer_eval.transform.util import ( + set_pair_ids, + check_grit_replicate_summary_method, +) def assign_replicates( @@ -88,11 +91,17 @@ def calculate_precision_recall(replicate_group_df: pd.DataFrame, k: int) -> pd.S def calculate_grit( - replicate_group_df: pd.DataFrame, control_perts: List[str], column_id_info: dict + replicate_group_df: pd.DataFrame, + control_perts: List[str], + column_id_info: dict, + replicate_summary_method: str = "mean", ) -> pd.Series: """ Usage: Designed to be called within a pandas.DataFrame().groupby().apply() """ + # Confirm that we support the provided summary method + check_grit_replicate_summary_method(replicate_summary_method) + group_entry = get_grit_entry(replicate_group_df, column_id_info["group"]["id"]) pert = get_grit_entry(replicate_group_df, column_id_info["replicate"]["id"]) @@ -125,7 +134,11 @@ def calculate_grit( scaler = StandardScaler() scaler.fit(control_distrib) grit_z_scores = scaler.transform(same_group_distrib) - grit = np.mean(grit_z_scores) + + if replicate_summary_method == "mean": + grit = np.mean(grit_z_scores) + elif replicate_summary_method == "median": + grit = np.median(grit_z_scores) return_bundle = {"perturbation": pert, "group": group_entry, "grit": grit} diff --git a/cytominer_eval/tests/test_evaluate.py b/cytominer_eval/tests/test_evaluate.py index 39575f1..57b8ced 100644 --- a/cytominer_eval/tests/test_evaluate.py +++ b/cytominer_eval/tests/test_evaluate.py @@ -110,11 +110,13 @@ def test_evaluate_replicate_reprod_return_cor_true(): assert top_genes == ["CDK2", "CCNE1", "ATF4", "KIF11", "CCND1"] assert np.round(med_cor_df.similarity_metric.max(), 3) == 0.949 - assert sorted(med_cor_df.columns.tolist()) == sorted([ - "Metadata_gene_name", - "Metadata_pert_name", - "similarity_metric", - ]) + assert sorted(med_cor_df.columns.tolist()) == sorted( + [ + "Metadata_gene_name", + "Metadata_pert_name", + "similarity_metric", + ] + ) def test_evaluate_precision_recall(): @@ -197,6 +199,7 @@ def test_evaluate_grit(): replicate_groups=grit_gene_replicate_groups, operation="grit", grit_control_perts=grit_gene_control_perts, + grit_replicate_summary_method="median", ) top_result = ( @@ -206,7 +209,7 @@ def test_evaluate_grit(): 0, ] ) - assert np.round(top_result.grit, 4) == 2.2597 + assert np.round(top_result.grit, 4) == 2.3352 assert top_result.group == "PTK2" assert top_result.perturbation == "PTK2-2" @@ -224,6 +227,7 @@ def test_evaluate_grit(): replicate_groups=grit_compound_replicate_groups, operation="grit", grit_control_perts=grit_compound_control_perts, + grit_replicate_summary_method="mean", ) top_result = ( diff --git a/cytominer_eval/tests/test_operations/test_grit.py b/cytominer_eval/tests/test_operations/test_grit.py index fb5ee45..80a0163 100644 --- a/cytominer_eval/tests/test_operations/test_grit.py +++ b/cytominer_eval/tests/test_operations/test_grit.py @@ -91,7 +91,7 @@ def test_calculate_grit(): expected_result = {"perturbation": "MTOR-2", "group": "MTOR", "grit": 1.55075} expected_result = pd.DataFrame(expected_result, index=["result"]).transpose() - assert_frame_equal(grit_result, expected_result, check_less_precise=True) + assert_frame_equal(grit_result, expected_result) # Calculate grit will not work with singleton perturbations # (no other perts in same group) @@ -107,7 +107,7 @@ def test_calculate_grit(): expected_result = {"perturbation": "AURKB-2", "group": "AURKB", "grit": np.nan} expected_result = pd.DataFrame(expected_result, index=["result"]).transpose() - assert_frame_equal(grit_result, expected_result, check_less_precise=True) + assert_frame_equal(grit_result, expected_result) # Calculate grit will not work with the full dataframe with pytest.raises(AssertionError) as ae: @@ -147,7 +147,7 @@ def test_grit(): expected_result = {"perturbation": "PTK2-2", "group": "PTK2", "grit": 4.61094} expected_result = pd.DataFrame(expected_result, index=[0]).transpose() - assert_frame_equal(top_result, expected_result, check_less_precise=True) + assert_frame_equal(top_result, expected_result) # There are six singletons in this dataset assert result.grit.isna().sum() == 6 @@ -157,3 +157,39 @@ def test_grit(): # With this data, we do not expect the sum of grit to change assert np.round(result.grit.sum(), 0) == 152.0 + + +def test_grit_summary_metric(): + result = grit( + similarity_melted_df=similarity_melted_df, + control_perts=control_perts, + replicate_id=replicate_id, + group_id=group_id, + replicate_summary_method="median", + ).sort_values(by="grit") + + assert all([x in result.columns for x in ["perturbation", "group", "grit"]]) + + top_result = pd.DataFrame( + result.sort_values(by="grit", ascending=False) + .reset_index(drop=True) + .iloc[0, :], + ) + + expected_result = {"perturbation": "PTK2-2", "group": "PTK2", "grit": 4.715917} + expected_result = pd.DataFrame(expected_result, index=[0]).transpose() + + assert_frame_equal( + top_result, + expected_result, + ) + + with pytest.raises(ValueError) as ve: + output = grit( + similarity_melted_df=similarity_melted_df, + control_perts=control_perts, + replicate_id=replicate_id, + group_id=group_id, + replicate_summary_method="fail", + ) + assert "method not supported, use one of:" in str(ve.value) diff --git a/cytominer_eval/tests/test_transform/test_util.py b/cytominer_eval/tests/test_transform/test_util.py index 6df0d2e..dc93431 100644 --- a/cytominer_eval/tests/test_transform/test_util.py +++ b/cytominer_eval/tests/test_transform/test_util.py @@ -9,6 +9,7 @@ from cytominer_eval.transform.util import ( get_available_eval_metrics, get_available_similarity_metrics, + get_available_grit_summary_methods, get_upper_matrix, convert_pandas_dtypes, assert_pandas_dtypes, @@ -17,6 +18,7 @@ assert_eval_metric, assert_melt, check_replicate_groups, + check_grit_replicate_summary_method, ) random.seed(123) @@ -48,6 +50,11 @@ def test_get_available_similarity_metrics(): assert expected_result == get_available_similarity_metrics() +def test_get_available_grit_summary_methods(): + expected_result = ["mean", "median"] + assert expected_result == get_available_grit_summary_methods() + + def test_assert_eval_metric(): with pytest.raises(AssertionError) as ae: output = assert_eval_metric(eval_metric="NOT SUPPORTED") @@ -159,3 +166,14 @@ def test_check_replicate_groups(): eval_metric="grit", replicate_groups=wrong_group_dict ) assert "replicate_groups for grit not formed properly." in str(ae.value) + + +def test_check_grit_replicate_summary_method(): + + # Pass + for metric in get_available_grit_summary_methods(): + check_grit_replicate_summary_method(metric) + + with pytest.raises(ValueError) as ve: + output = check_grit_replicate_summary_method("fail") + assert "method not supported, use one of:" in str(ve.value) diff --git a/cytominer_eval/transform/util.py b/cytominer_eval/transform/util.py index 53786f4..7c10721 100644 --- a/cytominer_eval/transform/util.py +++ b/cytominer_eval/transform/util.py @@ -13,6 +13,10 @@ def get_available_similarity_metrics(): return ["pearson", "kendall", "spearman"] +def get_available_grit_summary_methods(): + return ["mean", "median"] + + def get_upper_matrix(df: pd.DataFrame) -> np.array: return np.triu(np.ones(df.shape), k=1).astype(bool) @@ -148,3 +152,14 @@ def set_grit_column_info(replicate_id: str, group_id: str) -> dict: column_id_info = {"replicate": replicate_id_info, "group": group_id_info} return column_id_info + + +def check_grit_replicate_summary_method(replicate_summary_method: str): + avail_methods = get_available_grit_summary_methods() + + if replicate_summary_method not in avail_methods: + raise ValueError( + "{input} method not supported, use one of: {avail}".format( + input=replicate_summary_method, avail=avail_methods + ) + )