-
Notifications
You must be signed in to change notification settings - Fork 0
/
tfrecords_writing.py
95 lines (77 loc) · 3.32 KB
/
tfrecords_writing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""
Author: Zhensheng Wang
Date: 03/16/2022
Preprocess the raw data files and build tfrecords from them
Input:
- Zipped X-ray images and training labels
Output:
- By default, 200 tfrecords with batch size of 32 with balanced pneumonia cases
and controls
"""
from zipfile import ZipFile
import pandas as pd
from tqdm import tqdm
# import cv2
import shutil
import os
import numpy as np
# import matplotlib.pyplot as plt
import tensorflow as tf
import argparse
from utils import make_example, process_image, read_images, process_label, parse_example
# from functools import partial
# Command-line interface: output image size, input archive and sharding options.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--width', '-wt',
    type=int,
    default=227,
    help='Resized image width',
)
parser.add_argument(
    '--height', '-ht',
    type=int,
    default=227,
    help='Resized image height',
)
parser.add_argument(
    '--data_dir', '-dd',
    type=str,
    default='rsna-pneumonia-detection-challenge.zip',
    help='Zip file path',
)
parser.add_argument(
    '--batch_size', '-bs',
    type=int,
    default=32,
    help='Batch size in each tfrecord file',
)
parser.add_argument(
    '--num_shards', '-ns',
    type=int,
    default=200,
    help='Number of shards (default is 200, must be less than 945)',
)
# NOTE: arguments are parsed at import time, so importing this module reads sys.argv.
args = parser.parse_args()

# Settings taken from the command line.
WIDTH, HEIGHT = args.width, args.height
BATCH_SIZE, NUM_SHARDS = args.batch_size, args.num_shards
DATA_DIR = args.data_dir

# Fixed settings.
SEED = 42  # random_state used when shuffling the positive and negative samples
NUM_CLASSES = 2
NUM_CHANNELS = 3
IMG_SIZE = (WIDTH, HEIGHT, NUM_CHANNELS)
def split_into_tfrecords(img_files, labels, tf_shard):
    """Write one TFRecord shard holding the given images and their labels.

    Args:
        img_files: image identifiers, passed to ``read_images`` together with
            ``DATA_DIR`` and the target ``(WIDTH, HEIGHT)`` size.
        labels: one label per entry of ``img_files``, in the same order.
        tf_shard: shard index, used in the output file name.

    The shard is written to ``tfrecords/shard<tf_shard>_<count>.tfrecord``,
    where ``<count>`` is ``len(img_files)``.
    """
    target_size = (WIDTH, HEIGHT)
    # Every image is read and processed before the output file is opened.
    processed = [
        process_image(raw)
        for raw in read_images(img_files, DATA_DIR, target_size)
    ]
    shard_path = f"tfrecords/shard{tf_shard}_{len(img_files)}.tfrecord"
    with tf.io.TFRecordWriter(path=shard_path) as writer:
        # Labels are stored as given; zip() stops at the shorter sequence.
        for image, target in zip(processed, labels):
            writer.write(make_example(image, target))
def _split_evenly(frame, n_sections):
    """Yield ``n_sections`` consecutive row slices that together cover ``frame``.

    The boundaries follow the rule documented for ``np.array_split``: the
    first ``len(frame) % n_sections`` slices hold one extra row. Slicing with
    ``iloc`` avoids handing a DataFrame to ``np.array_split``, which goes
    through ``DataFrame.swapaxes`` (deprecated in recent pandas releases).
    """
    base, extra = divmod(len(frame), n_sections)
    start = 0
    for k in range(n_sections):
        stop = start + base + (1 if k < extra else 0)
        yield frame.iloc[start:stop]
        start = stop


def main():
    """Rebuild the ``tfrecords`` folder and fill it with balanced shards.

    Reads ``stage_2_train_labels.csv`` from the ``DATA_DIR`` zip archive, keeps
    one row per ``patientId``, shuffles the positive (``Target == 1``) and
    negative (``Target == 0``) rows with ``SEED``, and pairs roughly
    ``BATCH_SIZE // 2`` rows of each per shard. At most ``NUM_SHARDS`` shards
    are written through ``split_into_tfrecords``.

    Raises:
        ValueError: if ``BATCH_SIZE`` is below 2, ``NUM_SHARDS`` is negative,
            or either class has fewer than ``BATCH_SIZE // 2`` samples.
    """
    # Validate the settings before anything on disk is deleted.
    half_batch = BATCH_SIZE // 2
    if half_batch < 1:
        raise ValueError(f'--batch_size must be at least 2, got {BATCH_SIZE}')
    if NUM_SHARDS < 0:
        raise ValueError(f'--num_shards must not be negative, got {NUM_SHARDS}')

    # Start from an empty tfrecords folder.
    if os.path.isdir('tfrecords'):
        print('Tfrecords exist! Clean all files...')
        shutil.rmtree('./tfrecords')
    else:
        print('Tfrecords folder not available. Make directory!')
    os.mkdir('./tfrecords')

    # Load the label table; close the archive member as soon as it is parsed.
    with ZipFile(DATA_DIR, 'r') as archive:
        with archive.open('stage_2_train_labels.csv') as label_file:
            df_meta = pd.read_csv(label_file)
    df_meta = df_meta.drop_duplicates('patientId')

    # Divide positive and control samples into shards
    pos = df_meta[df_meta['Target'] == 1].sample(frac=1, random_state=SEED, ignore_index=True)
    neg = df_meta[df_meta['Target'] == 0].sample(frac=1, random_state=SEED, ignore_index=True)

    n_pos_sections = len(pos) // half_batch
    n_neg_sections = len(neg) // half_batch
    available = min(n_pos_sections, n_neg_sections)
    if available == 0:
        raise ValueError(
            f'Need at least {half_batch} positive and {half_batch} negative samples '
            f'per shard, found {len(pos)} positive and {len(neg)} negative.'
        )

    # Never promise more shards than the smaller class can supply.
    n_shards = min(NUM_SHARDS, available)
    if n_shards < NUM_SHARDS:
        print(f'Only {available} balanced shards can be built; requested {NUM_SHARDS}.')

    t_pos = t_neg = 0
    n_written = 0
    # range(n_shards) supplies the shard index and also caps the number of shards.
    shard_parts = zip(range(n_shards),
                      _split_evenly(pos, n_pos_sections),
                      _split_evenly(neg, n_neg_sections))
    for i, pos_split, neg_split in tqdm(shard_parts, total=n_shards):
        df_tmp = pd.concat((pos_split, neg_split), axis=0)
        img_files, labels = df_tmp['patientId'].tolist(), df_tmp['Target'].tolist()
        n_cases = sum(labels)  # Target is 0/1, so the sum counts the positives
        t_pos += n_cases
        t_neg += len(labels) - n_cases
        split_into_tfrecords(img_files, labels, i)
        n_written += 1

    # Report what was actually written, not what was requested.
    print(f'{n_written} tfrecords created!')
    print(f'{t_pos} pneumonia and {t_neg} normal samples.')


if __name__ == '__main__':
    main()