-
Notifications
You must be signed in to change notification settings - Fork 1
/
m2trans_nli_filter.py
46 lines (40 loc) · 1.66 KB
/
m2trans_nli_filter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from M2TransNLI import M2TransNLI
import torch
import pandas as pd
from tqdm import tqdm
import argparse
def softmax(x):
s = torch.exp(output)
return s / torch.sum(s)
def pred(nli_model, premise, hypothesis):
ret = nli_model.predict([premise], [hypothesis])[1][0]
assert(ret == 'entailment' or ret == 'contradiction' or ret == 'neutral')
return ret
def main(args):
model = M2TransNLI.load_model(args.m2trans_nli_model_path)
nli_model = M2TransNLI(model)
ve_output = pd.read_csv(args.input_path)
filtered_doc = []
for doc in tqdm(range(len(ve_output))):
reports = ve_output['Report Impression'][doc].split(args.delimiter)
report = reports[0]
added = 1
for i in range(1, len(reports)):
nli_label = pred(nli_model, report, reports[i])
if nli_label == 'neutral':
report += ' ' + reports[i]
added += 1
if added >= args.topk:
break
filtered_doc.append(report)
ve_output['filtered'] = filtered_doc
ve_output.to_csv(args.save_path, index = False)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--m2trans_nli_model_path', default='/n/data1/hms/dbmi/rajpurkar/lab/datasets/nlp/radnli/ifcc-code/resources/model_medrad_19k')
parser.add_argument('--input_path', default='../ALBEF/example.csv', help = 'path to the output of the VE module')
parser.add_argument('--save_path', default='example_m2trans_nli.csv')
parser.add_argument('--delimiter', default = '[SEP]')
parser.add_argument('--topk', type = int, default = 2)
args = parser.parse_args()
main(args)