generate_itm_train_file.py
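"""Generate the image-text matching (ITM) training file.

For every report impression, three training examples are emitted:
a positive pair (the image with its own impression), a hard negative
whose CheXbert label tuple lies within the smallest attainable L1
distance of the original, and a random negative sampled from the other
label buckets in proportion to their size. The result is dumped to
--save_path as a JSON list.
"""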
import pandas as pd
from collections import defaultdict
import numpy as np
import json
import queue
import copy
from tqdm import tqdm
import argparse


def main(args):
    # Load the impressions and index them by DICOM id so the image for
    # each report can be looked up directly.
    train_impressions = pd.read_json(args.train_path)
    train_impressions = train_impressions.set_index('dicom_id')
    dicom_ids = set(train_impressions.index)
    train_impressions_chexbert = pd.read_csv(args.train_chexbert_path)
    # Bucket every report by its 14-dimensional CheXbert label tuple.
    dataset = defaultdict(lambda: queue.Queue())
    for idx, row in train_impressions_chexbert.iterrows():
        dicom_id, study_id, subject_id = row['dicom_id'], row['study_id'], row['subject_id']
        if dicom_id not in dicom_ids:
            continue
        # Columns 4 onward hold the CheXbert labels.
        label = tuple(row.iloc[4:])
        assert len(label) == 14, 'length of the chexbert label detected is not 14'
        dataset[label].put({'dicom_id': dicom_id, 'study_id': study_id, 'sentence': row['Report Impression'], 'image': train_impressions['image'][dicom_id], 'label': None})
    # Sampling distribution over label tuples, proportional to bucket size.
    total_keys = list(dataset.keys())
    total_keys_freq = np.array([dataset[k].qsize() for k in total_keys], dtype=np.float64)
    total_keys_freq /= np.sum(total_keys_freq)
    train_files = []
    edits = np.zeros(14)  # histogram of edit distances used for hard negatives
    for key in tqdm(dataset.keys()):
        for i in range(dataset[key].qsize()):
            if i % 1000 == 0:
                print(i, dataset[key].qsize())
            # Rotate the queue so each element is visited exactly once.
            el = dataset[key].get()
            dataset[key].put(el)
            # Positive pair: the image with its own impression.
            positive = copy.deepcopy(el)
            positive['label'] = 'positive'
            train_files.append(positive)
            # Hard negative: an impression whose CheXbert labels are close
            # to, but different from, this element's labels.
            hard_negative = copy.deepcopy(el)
            hard_negative['label'] = 'negative'
            hard_negative_keys = []
            hard_negative_freq = []
            # Grow the allowed L1 distance until at least one candidate
            # label tuple falls within it.
            for edit_distance in range(1, 14):
                for cand in dataset.keys():
                    if cand == key:
                        continue
                    v = np.array(key) - np.array(cand)
                    if np.sum(np.abs(v)) <= edit_distance:
                        hard_negative_keys.append(cand)
                        hard_negative_freq.append(dataset[cand].qsize())
                if len(hard_negative_keys) > 0:
                    edits[edit_distance] += 1
                    break
            # Sample one candidate bucket proportionally to its size.
            hard_negative_freq = np.array(hard_negative_freq, dtype=np.float64)
            hard_negative_freq /= np.sum(hard_negative_freq)
            x = np.random.choice(len(hard_negative_keys), 1, p=hard_negative_freq)[0]
            hard_negative_key = hard_negative_keys[x]
            hard_negative_cand = dataset[hard_negative_key].get()
            hard_negative['sentence'] = hard_negative_cand['sentence']
            hard_negative['hard_negative_dicom_id'] = hard_negative_cand['dicom_id']
            hard_negative['edit_distance'] = edit_distance
            dataset[hard_negative_key].put(hard_negative_cand)
            train_files.append(hard_negative)
            # Random negative: an impression drawn from any other label bucket.
            negative = copy.deepcopy(el)
            negative['label'] = 'negative'
            while True:
                negative_key = total_keys[np.random.choice(np.arange(len(total_keys)), p=total_keys_freq)]
                if negative_key != key:
                    break
            negative_cand = dataset[negative_key].get()
            dataset[negative_key].put(negative_cand)
            negative['sentence'] = negative_cand['sentence']
            negative['negative_sentence_dicom_id'] = negative_cand['dicom_id']
            # Cast to int so the value is JSON-serializable (np.sum returns a numpy scalar).
            negative['edit_distance'] = int(np.sum(np.abs(np.array(negative_key) - np.array(key))))
            train_files.append(negative)
    with open(args.save_path, 'w') as f:
        json.dump(train_files, f)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_path', help='path to the train file (in JSON format) used to create the training set for image-text matching fine-tuning')
    parser.add_argument('--train_chexbert_path', help='path to the CheXbert labels (in CSV format) for the train file')
    parser.add_argument('--save_path', help='path to dump the output')
    args = parser.parse_args()
    main(args)
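A minimal usage sketch, assuming the script is run from the repository root; the paths below are hypothetical placeholders, not files shipped with the repo:

# Hypothetical invocation (placeholder paths):
#   python generate_itm_train_file.py --train_path train.json \
#       --train_chexbert_path train_chexbert.csv --save_path itm_train.json

import json
from collections import Counter

with open('itm_train.json') as f:  # placeholder output path
    records = json.load(f)

# Each record holds 'dicom_id', 'study_id', 'sentence', 'image', and a
# 'label' of 'positive' or 'negative'; negatives additionally carry the
# CheXbert 'edit_distance' to the original label tuple.
print(Counter(r['label'] for r in records))  # expect ~2 negatives per positive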