-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate.py
135 lines (116 loc) · 5.72 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import argparse, os
import pandas as pd
import scanpy as sc
import torch
import pyro
from collections import OrderedDict
from functools import reduce
from inferelator.postprocessing.model_metrics import CombinedMetric as CM
def log_normal_to_mean_and_var(mean_log: torch.Tensor, std_log: torch.Tensor, num_samples: int):
    """Closed-form mean and variance of a log-normal distribution.

    Parameters
    ----------
    mean_log, std_log : torch.Tensor
        Mean and standard deviation of the underlying normal (log-space).
    num_samples : int
        Unused here; kept for signature parity with
        logit_normal_to_mean_and_var, which estimates moments by sampling.

    Returns
    -------
    tuple of torch.Tensor
        (exp(mu + s^2/2), (exp(s^2) - 1) * exp(2*mu + s^2))
    """
    mean_calculated = torch.exp(mean_log + (std_log ** 2) / 2)
    var_calculated = (torch.exp(std_log ** 2) - 1) * torch.exp(2 * mean_log + std_log ** 2)
    return mean_calculated, var_calculated
def logit_normal_to_mean_and_var(mean_logit: torch.tensor, std_logit: torch.tensor, num_samples: int):
    """Monte-Carlo estimate of the mean and variance of a logit-normal variable.

    Draws `num_samples` samples from Normal(mean_logit, std_logit), pushes
    them through a sigmoid, and returns the empirical (mean, var) over the
    sample dimension. A logit-normal has no closed-form moments, hence the
    sampling (unlike the log-normal helper above).
    """
    samples = torch.sigmoid(
        torch.distributions.Normal(mean_logit, std_logit).sample((num_samples,))
    )
    sample_var, sample_mean = torch.var_mean(samples, 0)
    return sample_mean, sample_var
def get_U_means_dfs(exp_dirs, num_sampling_iters=100):
    """Load per-experiment TFA (U) posterior means plus their cross-run average.

    Parameters
    ----------
    exp_dirs : list of str
        Directories each containing a pyro param store snapshot
        ('best_iter.params') and label files 'U_obs_names.csv' /
        'U_var_names.csv'.
    num_sampling_iters : int
        Forwarded to log_normal_to_mean_and_var (currently unused there).

    Returns
    -------
    OrderedDict
        Maps each experiment dir to its U-means DataFrame, with a final
        'combined' entry holding the element-wise average across runs.
    """
    all_U_means = []
    for dname in exp_dirs:
        pyro.get_param_store().clear()
        pyro.get_param_store().load(os.path.join(dname, "best_iter.params"), map_location="cpu")
        # Only the posterior means are kept; the variance is discarded.
        # (Previously it was bound to `U_vars`, which was immediately
        # shadowed by the CSV of variable names read below.)
        U_means, _ = log_normal_to_mean_and_var(
            pyro.get_param_store()['U_means'].detach(),
            pyro.get_param_store()['U_stds'].detach(),
            num_sampling_iters
        )
        print("Shape of U_means:", U_means.shape)
        pyro.get_param_store().clear()
        U_obs = pd.read_csv(os.path.join(dname, "U_obs_names.csv"), sep=",", header=None)
        U_vars = pd.read_csv(os.path.join(dname, "U_var_names.csv"), sep=",", header=None)
        print("Shape of U_means after transformation:", U_means.shape)
        U_means_df = pd.DataFrame(data=U_means.cpu().numpy(), columns=U_vars[0], index=U_obs[0])
        all_U_means.append(U_means_df)
    # Element-wise mean across runs; fill_value=0 treats a label missing
    # from one run as zero instead of propagating NaN.
    combined_U_means = reduce(lambda a, b: a.add(b, fill_value=0), all_U_means) / len(all_U_means)
    # NOTE(review): U_obs / U_vars here are the label files of the LAST
    # experiment dir, so the combined frame is reindexed to that run's
    # labels — confirm all runs share identical label files.
    combined_U_means_df = pd.DataFrame(
        data=combined_U_means,
        columns=U_vars[0],
        index=U_obs[0]
    )
    U_means_dfs = OrderedDict()
    for i, dname in enumerate(exp_dirs):
        U_means_dfs[dname] = all_U_means[i]
    U_means_dfs['combined'] = combined_U_means_df
    return U_means_dfs
def get_A_means_dfs(exp_dirs, num_sampling_iters=100):
    """Build per-experiment GRN (A) posterior-mean DataFrames and their average.

    Each directory in `exp_dirs` must hold a pyro param store snapshot
    ('best_iter.params') and the label files 'V_obs_names.csv' /
    'V_var_names.csv'. Returns an OrderedDict keyed by experiment dir, with
    a final 'combined' entry holding the element-wise mean across runs.
    """
    per_exp_frames = []
    for exp_dir in exp_dirs:
        store = pyro.get_param_store()
        store.clear()
        store.load(os.path.join(exp_dir, "best_iter.params"), map_location="cpu")
        # Monte-Carlo moments of the sigmoid-transformed posterior.
        means, _vars = logit_normal_to_mean_and_var(
            store['A_means'].detach(),
            store['A_stds'].detach(),
            num_sampling_iters
        )
        store.clear()
        V_obs = pd.read_csv(os.path.join(exp_dir, "V_obs_names.csv"), sep=',', header=None)
        V_vars = pd.read_csv(os.path.join(exp_dir, "V_var_names.csv"), sep=',', header=None)
        per_exp_frames.append(
            pd.DataFrame(data=means.cpu().numpy(), columns=V_vars[0], index=V_obs[0])
        )
    # Average across runs; fill_value=0 zero-fills labels absent from a run.
    summed = reduce(lambda left, right: left.add(right, fill_value=0), per_exp_frames)
    combined_df = pd.DataFrame(
        data=summed / len(per_exp_frames),
        columns=V_vars[0],
        index=V_obs[0]
    )
    result = OrderedDict(zip(exp_dirs, per_exp_frames))
    result['combined'] = combined_df
    return result
def get_separate_and_combined_auprcs_A(A_means_dfs, gold_standard_path, filter_method='keep_all_gold_standard'):
    """Score each inferred network (and the consensus) against a gold standard.

    Both the gold-standard TSV and the inferred A matrices should have axes
    genes x TFs. Returns an OrderedDict of name -> AUPR, preserving the
    input ordering of `A_means_dfs`.
    """
    gold_df = sc.read_csv(gold_standard_path, delimiter="\t", first_column_names=True).to_df()
    return OrderedDict(
        (label, CM([frame], gold_df, filter_method).aupr)
        for label, frame in A_means_dfs.items()
    )
def main():
    """CLI entry point: write inferred GRN/TFA tables and, optionally, AUPRCs.

    Writes one TSV per experiment plus a consensus TSV for both the A (GRN)
    and U (TFA) matrices into --output-dir. If --gold-standard-path is given,
    also writes per-network AUPRCs and returns them; otherwise returns None.

    NOTE: alignment of enumerate() indices with --expression-names relies on
    'combined' being the LAST key of the OrderedDicts returned by
    get_A_means_dfs / get_U_means_dfs.
    """
    args = parser.parse_args()
    # exist_ok avoids the check-then-create race of a separate
    # os.path.exists() test.
    os.makedirs(args.output_dir, exist_ok=True)
    A_means_dfs = get_A_means_dfs(args.exp_dirs, args.num_sampling_iters)
    for i, (name, df) in enumerate(A_means_dfs.items()):
        if name == "combined":
            A_means_dfs[name].to_csv(os.path.join(args.output_dir, "inferred_consensus_grn.tsv"), sep="\t")
        else:
            A_means_dfs[name].to_csv(
                os.path.join(args.output_dir, "inferred_{}_grn.tsv".format(args.expression_names[i])), sep="\t"
            )
    U_means_dfs = get_U_means_dfs(args.exp_dirs, args.num_sampling_iters)
    for i, (name, df) in enumerate(U_means_dfs.items()):
        if name == "combined":
            U_means_dfs[name].to_csv(os.path.join(args.output_dir, "inferred_consensus_tfa.tsv"), sep="\t")
        else:
            U_means_dfs[name].to_csv(
                os.path.join(args.output_dir, "inferred_{}_tfa.tsv".format(args.expression_names[i])), sep="\t"
            )
    if args.gold_standard_path is not None:
        auprcs = get_separate_and_combined_auprcs_A(A_means_dfs, args.gold_standard_path)
        with open(os.path.join(args.output_dir, "auprcs.txt"), 'w') as f:
            for i, (name, auprc) in enumerate(auprcs.items()):
                if name == "combined":
                    f.write("consensus: " + str(auprc) + "\n")
                else:
                    f.write(args.expression_names[i] + ": " + str(auprc) + "\n")
        return auprcs
# Module-level CLI definition; main() references `parser`, so this must run
# at import time before main() is called.
parser = argparse.ArgumentParser()
parser.add_argument("--exp-dirs", nargs="+")
parser.add_argument("--expression-names", nargs="+", help="must be in the same order as for --exp-dirs")
# Optional: when omitted, main() skips AUPRC scoring entirely.
parser.add_argument("--gold-standard-path", default=None)
# Number of Monte-Carlo draws for the logit-normal moment estimates.
parser.add_argument("--num-sampling-iters", type=int, default=500)
parser.add_argument("--output-dir")
if __name__ == "__main__":
    main()