Skip to content

Commit

Permalink
add and test median grit summary
Browse files Browse the repository at this point in the history
  • Loading branch information
gwaybio committed Feb 11, 2021
1 parent 32a1491 commit e78c2a5
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 18 deletions.
8 changes: 7 additions & 1 deletion cytominer_eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def evaluate(
replicate_reproducibility_return_median_cor: bool = False,
precision_recall_k: int = 10,
grit_control_perts: List[str] = ["None"],
grit_replicate_summary_method: str = "mean",
mp_value_params: dict = {},
):
r"""Evaluate profile quality and strength.
Expand Down Expand Up @@ -86,10 +87,14 @@ def evaluate(
Only used when `operation='grit'`. Specific profile identifiers used as a
reference when calculating grit. The list entries must be found in the
`replicate_groups[replicate_id]` column.
grit_replicate_summary_method : {"mean", "median"}, optional
Only used when `operation='grit'`. Defines how the replicate z scores are
summarized. see
:py:func:`cytominer_eval.operations.util.calculate_grit`
mp_value_params : {{}, ...}, optional
Only used when `operation='mp_value'`. A key, item pair of optional parameters
for calculating mp value. See also
cytominer_eval.operations.util.default_mp_value_parameters
:py:func:`cytominer_eval.operations.util.default_mp_value_parameters`
"""
# Check replicate groups input
check_replicate_groups(eval_metric=operation, replicate_groups=replicate_groups)
Expand Down Expand Up @@ -124,6 +129,7 @@ def evaluate(
control_perts=grit_control_perts,
replicate_id=replicate_groups["replicate_id"],
group_id=replicate_groups["group_id"],
replicate_summary_method=grit_replicate_summary_method,
)
elif operation == "mp_value":
metric_result = mp_value(
Expand Down
22 changes: 17 additions & 5 deletions cytominer_eval/operations/grit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
from typing import List

from .util import assign_replicates, calculate_grit
from .util import assign_replicates, calculate_grit, check_grit_replicate_summary_method
from cytominer_eval.transform.util import (
set_pair_ids,
set_grit_column_info,
Expand All @@ -20,6 +20,7 @@ def grit(
control_perts: List[str],
replicate_id: str,
group_id: str,
replicate_summary_method: str = "mean",
) -> pd.DataFrame:
r"""Calculate grit
Expand All @@ -30,16 +31,20 @@ def grit(
control_perts : list
a list of control perturbations to calculate a null distribution
replicate_id : str
the metadata identifier marking which column tracks replicate perts
the metadata identifier marking which column tracks unique identifiers
group_id : str
the metadata identifier marking which column tracks a higher order groups for
all perturbations
the metadata identifier marking which column defines how replicates are grouped
replicate_summary_method : {'mean', 'median'}, optional
how replicate z-scores to control perts are summarized. Defaults to "mean".
Returns
-------
pandas.DataFrame
A dataframe of grit measurements per perturbation
"""
# Check if we support the provided summary method
check_grit_replicate_summary_method(replicate_summary_method)

# Determine pairwise replicates
similarity_melted_df = assign_replicates(
similarity_melted_df=similarity_melted_df,
Expand All @@ -61,7 +66,14 @@ def grit(
# Calculate grit for each perturbation
grit_df = (
similarity_melted_df.groupby(replicate_col_name)
.apply(lambda x: calculate_grit(x, control_perts, column_id_info))
.apply(
lambda x: calculate_grit(
replicate_group_df=x,
control_perts=control_perts,
column_id_info=column_id_info,
replicate_summary_method=replicate_summary_method,
)
)
.reset_index(drop=True)
)

Expand Down
19 changes: 16 additions & 3 deletions cytominer_eval/operations/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from sklearn.covariance import EmpiricalCovariance

from cytominer_eval.transform import metric_melt
from cytominer_eval.transform.util import set_pair_ids
from cytominer_eval.transform.util import (
set_pair_ids,
check_grit_replicate_summary_method,
)


def assign_replicates(
Expand Down Expand Up @@ -88,11 +91,17 @@ def calculate_precision_recall(replicate_group_df: pd.DataFrame, k: int) -> pd.S


def calculate_grit(
replicate_group_df: pd.DataFrame, control_perts: List[str], column_id_info: dict
replicate_group_df: pd.DataFrame,
control_perts: List[str],
column_id_info: dict,
replicate_summary_method: str = "mean",
) -> pd.Series:
"""
Usage: Designed to be called within a pandas.DataFrame().groupby().apply()
"""
# Confirm that we support the provided summary method
check_grit_replicate_summary_method(replicate_summary_method)

group_entry = get_grit_entry(replicate_group_df, column_id_info["group"]["id"])
pert = get_grit_entry(replicate_group_df, column_id_info["replicate"]["id"])

Expand Down Expand Up @@ -125,7 +134,11 @@ def calculate_grit(
scaler = StandardScaler()
scaler.fit(control_distrib)
grit_z_scores = scaler.transform(same_group_distrib)
grit = np.mean(grit_z_scores)

if replicate_summary_method == "mean":
grit = np.mean(grit_z_scores)
elif replicate_summary_method == "median":
grit = np.median(grit_z_scores)

return_bundle = {"perturbation": pert, "group": group_entry, "grit": grit}

Expand Down
16 changes: 10 additions & 6 deletions cytominer_eval/tests/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,13 @@ def test_evaluate_replicate_reprod_return_cor_true():
assert top_genes == ["CDK2", "CCNE1", "ATF4", "KIF11", "CCND1"]

assert np.round(med_cor_df.similarity_metric.max(), 3) == 0.949
assert sorted(med_cor_df.columns.tolist()) == sorted([
"Metadata_gene_name",
"Metadata_pert_name",
"similarity_metric",
])
assert sorted(med_cor_df.columns.tolist()) == sorted(
[
"Metadata_gene_name",
"Metadata_pert_name",
"similarity_metric",
]
)


def test_evaluate_precision_recall():
Expand Down Expand Up @@ -197,6 +199,7 @@ def test_evaluate_grit():
replicate_groups=grit_gene_replicate_groups,
operation="grit",
grit_control_perts=grit_gene_control_perts,
grit_replicate_summary_method="median",
)

top_result = (
Expand All @@ -206,7 +209,7 @@ def test_evaluate_grit():
0,
]
)
assert np.round(top_result.grit, 4) == 2.2597
assert np.round(top_result.grit, 4) == 2.3352
assert top_result.group == "PTK2"
assert top_result.perturbation == "PTK2-2"

Expand All @@ -224,6 +227,7 @@ def test_evaluate_grit():
replicate_groups=grit_compound_replicate_groups,
operation="grit",
grit_control_perts=grit_compound_control_perts,
grit_replicate_summary_method="mean",
)

top_result = (
Expand Down
42 changes: 39 additions & 3 deletions cytominer_eval/tests/test_operations/test_grit.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_calculate_grit():
expected_result = {"perturbation": "MTOR-2", "group": "MTOR", "grit": 1.55075}
expected_result = pd.DataFrame(expected_result, index=["result"]).transpose()

assert_frame_equal(grit_result, expected_result, check_less_precise=True)
assert_frame_equal(grit_result, expected_result)

# Calculate grit will not work with singleton perturbations
# (no other perts in same group)
Expand All @@ -107,7 +107,7 @@ def test_calculate_grit():
expected_result = {"perturbation": "AURKB-2", "group": "AURKB", "grit": np.nan}
expected_result = pd.DataFrame(expected_result, index=["result"]).transpose()

assert_frame_equal(grit_result, expected_result, check_less_precise=True)
assert_frame_equal(grit_result, expected_result)

# Calculate grit will not work with the full dataframe
with pytest.raises(AssertionError) as ae:
Expand Down Expand Up @@ -147,7 +147,7 @@ def test_grit():
expected_result = {"perturbation": "PTK2-2", "group": "PTK2", "grit": 4.61094}
expected_result = pd.DataFrame(expected_result, index=[0]).transpose()

assert_frame_equal(top_result, expected_result, check_less_precise=True)
assert_frame_equal(top_result, expected_result)

# There are six singletons in this dataset
assert result.grit.isna().sum() == 6
Expand All @@ -157,3 +157,39 @@ def test_grit():

# With this data, we do not expect the sum of grit to change
assert np.round(result.grit.sum(), 0) == 152.0


def test_grit_summary_metric():
    """Check grit with replicate_summary_method='median' and the ValueError path.

    Fixes vs. original: dropped a redundant `.sort_values(by="grit")` on the
    initial result (it is re-sorted descending immediately below) and the dead
    `output =` binding inside the pytest.raises block (the call raises before
    assignment, so the name was never bound).
    """
    # Median summarization should change the top grit value vs. the mean default
    result = grit(
        similarity_melted_df=similarity_melted_df,
        control_perts=control_perts,
        replicate_id=replicate_id,
        group_id=group_id,
        replicate_summary_method="median",
    )

    assert all([x in result.columns for x in ["perturbation", "group", "grit"]])

    # Highest-grit perturbation, as a single-column DataFrame for comparison
    top_result = pd.DataFrame(
        result.sort_values(by="grit", ascending=False)
        .reset_index(drop=True)
        .iloc[0, :],
    )

    expected_result = {"perturbation": "PTK2-2", "group": "PTK2", "grit": 4.715917}
    expected_result = pd.DataFrame(expected_result, index=[0]).transpose()

    assert_frame_equal(
        top_result,
        expected_result,
    )

    # An unsupported summary method must raise an informative ValueError
    with pytest.raises(ValueError) as ve:
        grit(
            similarity_melted_df=similarity_melted_df,
            control_perts=control_perts,
            replicate_id=replicate_id,
            group_id=group_id,
            replicate_summary_method="fail",
        )
    assert "method not supported, use one of:" in str(ve.value)
18 changes: 18 additions & 0 deletions cytominer_eval/tests/test_transform/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from cytominer_eval.transform.util import (
get_available_eval_metrics,
get_available_similarity_metrics,
get_available_grit_summary_methods,
get_upper_matrix,
convert_pandas_dtypes,
assert_pandas_dtypes,
Expand All @@ -17,6 +18,7 @@
assert_eval_metric,
assert_melt,
check_replicate_groups,
check_grit_replicate_summary_method,
)

random.seed(123)
Expand Down Expand Up @@ -48,6 +50,11 @@ def test_get_available_similarity_metrics():
assert expected_result == get_available_similarity_metrics()


def test_get_available_grit_summary_methods():
    """The supported grit summary methods are mean and median, in that order."""
    assert get_available_grit_summary_methods() == ["mean", "median"]


def test_assert_eval_metric():
with pytest.raises(AssertionError) as ae:
output = assert_eval_metric(eval_metric="NOT SUPPORTED")
Expand Down Expand Up @@ -159,3 +166,14 @@ def test_check_replicate_groups():
eval_metric="grit", replicate_groups=wrong_group_dict
)
assert "replicate_groups for grit not formed properly." in str(ae.value)


def test_check_grit_replicate_summary_method():
    """Supported methods pass silently; anything else raises ValueError."""
    # Every advertised summary method must be accepted without raising
    for summary_method in get_available_grit_summary_methods():
        check_grit_replicate_summary_method(summary_method)

    # An unknown method should raise with a helpful message
    with pytest.raises(ValueError) as ve:
        check_grit_replicate_summary_method("fail")
    assert "method not supported, use one of:" in str(ve.value)
15 changes: 15 additions & 0 deletions cytominer_eval/transform/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ def get_available_similarity_metrics():
return ["pearson", "kendall", "spearman"]


def get_available_grit_summary_methods():
    """Return the replicate summary methods supported by grit.

    Returns
    -------
    list of str
        The method names ("mean", "median") accepted by
        `check_grit_replicate_summary_method` and `calculate_grit`.
    """
    return ["mean", "median"]


def get_upper_matrix(df: pd.DataFrame) -> np.array:
return np.triu(np.ones(df.shape), k=1).astype(bool)

Expand Down Expand Up @@ -148,3 +152,14 @@ def set_grit_column_info(replicate_id: str, group_id: str) -> dict:

column_id_info = {"replicate": replicate_id_info, "group": group_id_info}
return column_id_info


def check_grit_replicate_summary_method(replicate_summary_method: str):
    """Validate a grit replicate summary method name.

    Parameters
    ----------
    replicate_summary_method : str
        The summary method requested by the caller.

    Raises
    ------
    ValueError
        If the method is not one of `get_available_grit_summary_methods()`.
    """
    avail_methods = get_available_grit_summary_methods()

    # Guard clause: supported methods pass straight through
    if replicate_summary_method in avail_methods:
        return

    raise ValueError(
        "{input} method not supported, use one of: {avail}".format(
            input=replicate_summary_method, avail=avail_methods
        )
    )

0 comments on commit e78c2a5

Please sign in to comment.