Cache fix in mondo_score_utils, now faster (#27)
* added profiling tool and started what is probably correct caching

* caching mondo to omim in a smart and working way

* Update runner.py

Got rid of profiling tool in runner, forgot to do this before

* removed unnecessary comments, fixed docstring that was causing issues with CI

* edited mondo adapter in docstring

* add import to make docstring work

* add get_adapter in docstring

* add get_adapter method import to make the docstring work; added get_adapter to one function's docstring

* embellishment of docstring
leokim-l authored Jun 21, 2024
1 parent 35fc065 commit a2181fd
Showing 3 changed files with 65 additions and 41 deletions.
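
The change at the heart of this commit is a switch from functools.lru_cache to cachetools.cached, so that the Mondo adapter can be passed in as an explicit argument while the cache key is computed from the ontology term alone, and so that hit/miss statistics can be inspected via info=True. A minimal standalone sketch of that pattern (FakeAdapter and mappings are illustrative names, not the repo's code):

```python
from typing import List
from cachetools import cached, LRUCache
from cachetools.keys import hashkey


class FakeAdapter:
    """Hypothetical stand-in for the oaklib Mondo adapter (illustration only)."""
    def lookup(self, term: str) -> List[str]:
        return ["OMIM:132800"] if term == "MONDO:0007566" else []


# Only `term` is hashed into the cache key; the adapter is an argument but not part of the key.
@cached(cache=LRUCache(maxsize=1024), info=True, key=lambda term, adapter: hashkey(term))
def mappings(term: str, adapter: FakeAdapter) -> List[str]:
    return adapter.lookup(term)


adapter = FakeAdapter()
mappings("MONDO:0007566", adapter)   # miss: computed and stored
mappings("MONDO:0007566", adapter)   # hit: served from the LRU cache
print(mappings.cache_info())         # e.g. CacheInfo(hits=1, misses=1, maxsize=1024, currsize=1)
```
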
40 changes: 35 additions & 5 deletions src/malco/post_process/compute_mrr.py
@@ -4,15 +4,29 @@
import pandas as pd
import pickle as pkl
from malco.post_process.mondo_score_utils import score_grounded_result
from malco.post_process.mondo_score_utils import omim_mappings
from typing import List
from oaklib.interfaces import OboGraphInterface


from oaklib import get_adapter


def mondo_adapter() -> OboGraphInterface:
"""
Get the adapter for the MONDO ontology.
Returns:
Adapter: The adapter.
"""
return get_adapter("sqlite:obo:mondo")

def compute_mrr(output_dir, prompt_dir, correct_answer_file) -> Path:
# Read in results TSVs from self.output_dir that match glob results*tsv
#TODO Leo: make more robust, had other results*tsv files from previous testing
# Proposal, go for exact file name match defined somewhere as global/static/immutable
results_data = []
results_files = []
num_ppkt = 0

for subdir, dirs, files in os.walk(output_dir):
for filename in files:
if filename.startswith("result") and filename.endswith(".tsv"):
@@ -32,6 +46,9 @@ def compute_mrr(output_dir, prompt_dir, correct_answer_file) -> Path:
label_to_correct_term = answers.set_index("label")["term"].to_dict()
# Calculate the Mean Reciprocal Rank (MRR) for each file
mrr_scores = []

mondo = mondo_adapter()

for df in results_data:
# For each label in the results file, find if the correct term is ranked
df["rank"] = df.groupby("label")["score"].rank(ascending=False, method="first")
@@ -42,17 +59,30 @@ def compute_mrr(output_dir, prompt_dir, correct_answer_file) -> Path:
# df['correct_term'] is an OMIM
# call OAK and get OMIM IDs for df['term'] and see if df['correct_term'] is one of them
# in the case of phenotypic series, if Mondo corresponds to grouping term, accept it
df['is_correct'] = df.apply(
lambda row: score_grounded_result(row['term'], row['correct_term']) > 0,
axis=1)

# Calculate reciprocal rank
# Make sure caching is used in the following by unwrapping explicitly
results = []
for idx, row in df.iterrows():
val = score_grounded_result(row['term'], row['correct_term'], mondo)
is_correct = val > 0
results.append(is_correct)

df['is_correct'] = results

df["reciprocal_rank"] = df.apply(
lambda row: 1 / row["rank"] if row["is_correct"] else 0, axis=1
)
# Calculate MRR for this file
mrr = df.groupby("label")["reciprocal_rank"].max().mean()
mrr_scores.append(mrr)
print('=' * 100)
print('score_grounded_result cache info:\n')
print(score_grounded_result.cache_info())
print('=' * 100)
print('omim_mappings cache info:\n')
print(omim_mappings.cache_info())
print('=' * 100)

print("MRR scores are:\n")
print(mrr_scores)
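
The loop added above builds the Mondo adapter once, threads it through the cached scoring calls row by row, and then prints the cache statistics. A small self-contained sketch of that shape, assuming pandas and cachetools are available (DummyAdapter and score_row are hypothetical stand-ins, not the repo's functions):

```python
import pandas as pd
from cachetools import cached, LRUCache
from cachetools.keys import hashkey


class DummyAdapter:
    """Hypothetical stand-in for the Mondo adapter, illustration only."""
    def exact_omim(self, term: str) -> str:
        return "OMIM:132800" if term == "MONDO:0007566" else "OMIM:000000"


@cached(cache=LRUCache(maxsize=4096), info=True,
        key=lambda term, truth, adapter: hashkey(term, truth))
def score_row(term: str, truth: str, adapter) -> float:
    # Only (term, truth) form the cache key; the adapter is passed through unkeyed.
    return 1.0 if adapter.exact_omim(term) == truth else 0.0


adapter = DummyAdapter()  # built once, outside the per-row loop
df = pd.DataFrame({"term": ["MONDO:0007566"] * 3,
                   "correct_term": ["OMIM:132800"] * 3})

results = []
for _, row in df.iterrows():
    results.append(score_row(row["term"], row["correct_term"], adapter) > 0)
df["is_correct"] = results

print(score_row.cache_info())  # repeated (term, truth) pairs show up as cache hits
```
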
63 changes: 29 additions & 34 deletions src/malco/post_process/mondo_score_utils.py
@@ -1,90 +1,85 @@
from functools import lru_cache
from typing import List

from oaklib import get_adapter
from oaklib.datamodels.vocabulary import IS_A
from oaklib.interfaces import OboGraphInterface, MappingProviderInterface

FULL_SCORE = 1.0
PARTIAL_SCORE = 0.5
from oaklib.interfaces import MappingProviderInterface

from typing import List
from cachetools import cached, LRUCache
from cachetools.keys import hashkey

@lru_cache(maxsize=4096)
def mondo_adapter() -> OboGraphInterface:
"""
Get the adapter for the MONDO ontology.

Returns:
Adapter: The adapter.
"""
return get_adapter("sqlite:obo:mondo")

FULL_SCORE = 1.0
PARTIAL_SCORE = 0.5

@lru_cache(maxsize=1024)
def omim_mappings(term: str) -> List[str]:
@cached(cache=LRUCache(maxsize=16384), info=True, key=lambda term, adapter: hashkey(term))
def omim_mappings(term: str, adapter) -> List[str]:
"""
Get the OMIM mappings for a term.
Example:
>>> omim_mappings("MONDO:0007566")
>>> from oaklib import get_adapter
>>> omim_mappings("MONDO:0007566", get_adapter("sqlite:obo:mondo"))
['OMIM:132800']
Args:
term (str): The term.
adapter: The mondo adapter.
Returns:
str: The OMIM mappings.
"""
adapter = mondo_adapter()
if not isinstance(adapter, MappingProviderInterface):
raise ValueError("Adapter is not an MappingProviderInterface")
"""
omims = []
for m in adapter.sssom_mappings([term], "OMIM"):
for m in adapter.sssom_mappings([term], source="OMIM"):
if m.predicate_id == "skos:exactMatch":
omims.append(m.object_id)
return omims


def score_grounded_result(prediction: str, ground_truth: str) -> float:
@cached(cache=LRUCache(maxsize=4096), info=True, key=lambda prediction, ground_truth, mondo: hashkey(prediction, ground_truth))
def score_grounded_result(prediction: str, ground_truth: str, mondo) -> float:
"""
Score the grounded result.
Exact match:
>>> score_grounded_result("OMIM:132800", "OMIM:132800")
>>> from oaklib import get_adapter
>>> score_grounded_result("OMIM:132800", "OMIM:132800", get_adapter("sqlite:obo:mondo"))
1.0
The predicted Mondo is equivalent to the ground truth OMIM
(via skos:exactMatches in Mondo):
>>> score_grounded_result("MONDO:0007566", "OMIM:132800")
>>> score_grounded_result("MONDO:0007566", "OMIM:132800", get_adapter("sqlite:obo:mondo"))
1.0
The predicted Mondo is a disease entity that groups multiple
OMIMs, one of which is the ground truth:
>>> score_grounded_result("MONDO:0008029", "OMIM:158810")
>>> score_grounded_result("MONDO:0008029", "OMIM:158810", get_adapter("sqlite:obo:mondo"))
0.5
Args:
prediction (str): The prediction.
ground_truth (str): The ground truth.
mondo: The mondo adapter.
Returns:
float: The score.
"""
if not isinstance(mondo, MappingProviderInterface):
raise ValueError("Adapter is not an MappingProviderInterface")

if prediction == ground_truth:
# prediction is the correct OMIM
return FULL_SCORE
if ground_truth in omim_mappings(prediction):

#if ground_truth in omim_mappings(prediction, mondo):
if ground_truth in omim_mappings(prediction, mondo):
# prediction is a MONDO that directly maps to a correct OMIM
return FULL_SCORE
mondo = mondo_adapter()

descendants_list = mondo.descendants([prediction], predicates=[IS_A], reflexive=True)
for mondo_descendant in descendants_list:
if ground_truth in omim_mappings(mondo_descendant):
if ground_truth in omim_mappings(mondo_descendant, mondo):
# prediction is a MONDO that maps to a correct OMIM via a descendant
return PARTIAL_SCORE
return 0.0
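
For reference, the new signatures would be called roughly as follows (a usage sketch, not from the repo; it assumes oaklib can fetch the sqlite:obo:mondo build, and the expected outputs are taken from the doctests above):

```python
from oaklib import get_adapter
from malco.post_process.mondo_score_utils import omim_mappings, score_grounded_result

# Build the Mondo adapter once and reuse it for every call.
mondo = get_adapter("sqlite:obo:mondo")

print(omim_mappings("MONDO:0007566", mondo))                          # ['OMIM:132800']
print(score_grounded_result("OMIM:132800", "OMIM:132800", mondo))     # 1.0, exact match
print(score_grounded_result("MONDO:0008029", "OMIM:158810", mondo))   # 0.5, grouping term
print(omim_mappings.cache_info())             # hit/miss stats from cachetools' info=True
print(score_grounded_result.cache_info())
```
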
3 changes: 1 addition & 2 deletions src/malco/runner.py
@@ -10,7 +10,6 @@
from malco.post_process.generate_plots import make_plots
import os


@dataclass
class MalcoRunner(PhEvalRunner):
input_dir: Path
@@ -39,12 +38,12 @@ def run(self):
Run the tool to produce the raw output.
"""
print("running with predictor")

run(testdata_dir=self.testdata_dir,
raw_results_dir=self.raw_results_dir,
input_dir=self.input_dir,
langs=self.languages)


def post_process(self,
print_plot=True,
prompts_subdir_name="prompts",
