-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #22 from monarch-initiative/post_process_add_mondo…
…_utils Post process add mondo utils
- Loading branch information
Showing
19 changed files
with
1,389 additions
and
80 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import os | ||
import csv | ||
from pathlib import Path | ||
import pandas as pd | ||
import pickle as pkl | ||
from malco.post_process.mondo_score_utils import score_grounded_result | ||
|
||
|
||
def compute_mrr(output_dir, prompt_dir, correct_answer_file) -> Path: | ||
# Read in results TSVs from self.output_dir that match glob results*tsv | ||
#TODO Leo: make more robust, had other results*tsv files from previous testing | ||
# Proposal, go for exact file name match defined somewhere as global/static/immutable | ||
results_data = [] | ||
results_files = [] | ||
num_ppkt = 0 | ||
for subdir, dirs, files in os.walk(output_dir): | ||
for filename in files: | ||
if filename.startswith("result") and filename.endswith(".tsv"): | ||
file_path = os.path.join(subdir, filename) | ||
df = pd.read_csv(file_path, sep="\t") | ||
num_ppkt = df["label"].nunique() | ||
results_data.append(df) | ||
# Append both the subdirectory relative to output_dir and the filename | ||
results_files.append(os.path.relpath(file_path, output_dir)) | ||
# Read in correct answers from prompt_dir | ||
answers_path = os.path.join(os.getcwd(), prompt_dir, correct_answer_file) | ||
answers = pd.read_csv( | ||
answers_path, sep="\t", header=None, names=["description", "term", "label"] | ||
) | ||
|
||
# Mapping each label to its correct term | ||
label_to_correct_term = answers.set_index("label")["term"].to_dict() | ||
# Calculate the Mean Reciprocal Rank (MRR) for each file | ||
mrr_scores = [] | ||
for df in results_data: | ||
# For each label in the results file, find if the correct term is ranked | ||
df["rank"] = df.groupby("label")["score"].rank(ascending=False, method="first") | ||
label_4_non_eng = df["label"].str.replace("_[a-z][a-z]-prompt", "_en-prompt", regex=True) | ||
df["correct_term"] = label_4_non_eng.map(label_to_correct_term) | ||
|
||
# df['term'] is Mondo or OMIM ID, or even disease label | ||
# df['correct_term'] is an OMIM | ||
# call OAK and get OMIM IDs for df['term'] and see if df['correct_term'] is one of them | ||
# in the case of phenotypic series, if Mondo corresponds to grouping term, accept it | ||
df['is_correct'] = df.apply( | ||
lambda row: score_grounded_result(row['term'], row['correct_term']) > 0, | ||
axis=1) | ||
|
||
# Calculate reciprocal rank | ||
df["reciprocal_rank"] = df.apply( | ||
lambda row: 1 / row["rank"] if row["is_correct"] else 0, axis=1 | ||
) | ||
# Calculate MRR for this file | ||
mrr = df.groupby("label")["reciprocal_rank"].max().mean() | ||
mrr_scores.append(mrr) | ||
|
||
print("MRR scores are:\n") | ||
print(mrr_scores) | ||
plot_dir = output_dir / "plots" | ||
plot_dir.mkdir(exist_ok=True) | ||
plot_data_file = plot_dir / "plotting_data.tsv" | ||
|
||
# write out results for plotting | ||
with plot_data_file.open('w', newline = '') as dat: | ||
writer = csv.writer(dat, quoting = csv.QUOTE_NONNUMERIC, delimiter = '\t', lineterminator='\n') | ||
writer.writerow(results_files) | ||
writer.writerow(mrr_scores) | ||
return plot_data_file, plot_dir, num_ppkt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import seaborn as sns | ||
import matplotlib.pyplot as plt | ||
import os | ||
import csv | ||
|
||
# Make a nice plot, use it as function or as script | ||
|
||
def make_plots(plot_data_file, plot_dir, languages, num_ppkt): | ||
with plot_data_file.open('r', newline = '') as f: | ||
lines = csv.reader(f, quoting = csv.QUOTE_NONNUMERIC, delimiter = '\t', lineterminator='\n') | ||
results_files = next(lines) | ||
mrr_scores = next(lines) | ||
#lines = f.read().splitlines() | ||
|
||
print(results_files) | ||
print(mrr_scores) | ||
|
||
# Plotting the results | ||
sns.barplot(x = results_files, y = mrr_scores) | ||
plt.xlabel("Results File") | ||
plt.ylabel("Mean Reciprocal Rank (MRR)") | ||
plt.title("MRR of Correct Answers Across Different Results Files") | ||
plot_path = plot_dir / (str(len(languages)) + "_langs_" + str(num_ppkt) + "ppkt.png") | ||
plt.savefig(plot_path) | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,22 @@ | ||
from pathlib import Path | ||
|
||
from malco.post_process.post_process_results_format import create_standardised_results | ||
import os | ||
|
||
|
||
def post_process(raw_results_dir: Path, output_dir: Path) -> None: | ||
def post_process(raw_results_dir: Path, output_dir: Path, langs: tuple) -> None: | ||
""" | ||
Post-process the raw results output to standardised PhEval TSV format. | ||
Args: | ||
raw_results_dir (Path): Path to the raw results directory. | ||
output_dir (Path): Path to the output directory. | ||
""" | ||
create_standardised_results(raw_results_dir=raw_results_dir, output_dir=output_dir) | ||
|
||
for lang in langs: | ||
raw_results_lang = raw_results_dir / lang | ||
output_lang = output_dir / lang | ||
raw_results_lang.mkdir(exist_ok=True) | ||
output_lang.mkdir(exist_ok=True) | ||
|
||
create_standardised_results(raw_results_dir=raw_results_lang, | ||
output_dir=output_lang, output_file_name = "results.tsv") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import zipfile | ||
import os | ||
import requests | ||
|
||
phenopacket_zip_url="https://github.com/monarch-initiative/phenopacket-store/releases/download/0.1.11/all_phenopackets.zip" | ||
# TODO just point to a folder w/ ppkts | ||
phenopacket_dir="phenopacket-store" | ||
|
||
def setup_phenopackets(self) -> str: | ||
phenopacket_store_path = os.path.join(self.input_dir, phenopacket_dir) | ||
if os.path.exists(phenopacket_store_path): | ||
print(f"{phenopacket_store_path} exists, skipping download.") | ||
else: | ||
print(f"{phenopacket_store_path} doesn't exist, downloading phenopackets...") | ||
download_phenopackets(self, phenopacket_zip_url, phenopacket_dir) | ||
return phenopacket_store_path | ||
|
||
|
||
def download_phenopackets(self, phenopacket_zip_url, phenopacket_dir): | ||
# Ensure the directory for storing the phenopackets exists | ||
phenopacket_store_path = os.path.join(self.input_dir, phenopacket_dir) | ||
os.makedirs(phenopacket_store_path, exist_ok=True) | ||
|
||
# Download the phenopacket release zip file | ||
response = requests.get(phenopacket_zip_url) | ||
zip_path = os.path.join(self.input_dir, "all_phenopackets.zip") | ||
with open(zip_path, "wb") as f: | ||
f.write(response.content) | ||
print("Download completed.") | ||
|
||
# Unzip the phenopacket release zip file | ||
with zipfile.ZipFile(zip_path, "r") as zip_ref: | ||
zip_ref.extractall(phenopacket_store_path) | ||
print("Unzip completed.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.