-
Notifications
You must be signed in to change notification settings - Fork 0
/
BrainTumorData.py
57 lines (45 loc) · 2.05 KB
/
BrainTumorData.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import glob
import os

import torch
import torchio as tio
class BrainTumorData():
    """Build torchio train/validation/test datasets from a folder of NIfTI volumes.

    Expected layout (as implied by the glob patterns in ``get_image_lists``)::

        image_folder/<case>/<...><img_type>.nii.gz   # image volume
        image_folder/<case>/<...>seg.nii.gz          # segmentation label map

    Usage: call ``prepare_data()`` and then ``setup()``. Afterwards
    ``train_set``, ``val_set`` and ``test_set`` hold ``tio.SubjectsDataset``
    objects.

    Args:
        image_folder: Root folder containing one sub-folder per case.
        img_type: Filename suffix (before ``.nii.gz``) that selects the image
            files, e.g. ``"t1"`` (presumably an MRI modality name -- confirm).
        splits: Three lengths (or fractions, where the installed torch
            supports them) passed to ``torch.utils.data.random_split``.
    """

    def __init__(self, image_folder, img_type, splits):
        super().__init__()
        self.image_folder = image_folder
        self.img_type = img_type
        self.splits = splits

    def get_image_lists(self):
        """Return sorted lists of image paths and label paths.

        Files are matched exactly two levels below ``image_folder``. Images
        match ``*<img_type>.nii.gz`` and labels match ``*seg.nii.gz``. Images
        and labels are later paired purely by sorted position.

        Returns:
            A ``(image_paths, label_paths)`` tuple of equally long lists.

        Raises:
            ValueError: If the number of images and labels differ.
        """
        # os.path.join instead of a hard-coded "\\" separator: on POSIX a
        # backslash escapes the "*" wildcard and the glob matches nothing.
        image_pattern = os.path.join(
            self.image_folder, "*", "*" + self.img_type + ".nii.gz"
        )
        label_pattern = os.path.join(self.image_folder, "*", "*seg.nii.gz")
        image_paths = sorted(glob.glob(image_pattern))
        label_paths = sorted(glob.glob(label_pattern))
        # Explicit check rather than `assert`, which is stripped under
        # `python -O`. Without it, the zip() in prepare_data would silently
        # truncate mismatched lists.
        if len(image_paths) != len(label_paths):
            raise ValueError(
                f"Found {len(image_paths)} '{self.img_type}' images but "
                f"{len(label_paths)} segmentation labels under "
                f"{self.image_folder!r}"
            )
        print(len(image_paths))
        return image_paths, label_paths

    def prepare_data(self):
        """Create one ``tio.Subject`` per image/label pair in ``self.subjects``.

        Each subject stores the image under the key ``mri`` and the label map
        under the key ``brain``.
        """
        image_paths, label_paths = self.get_image_lists()
        self.subjects = [
            tio.Subject(
                mri=tio.ScalarImage(image_path),
                brain=tio.LabelMap(label_path),
            )
            for image_path, label_path in zip(image_paths, label_paths)
        ]
        print('Dataset size:', len(self.subjects), 'subjects')

    def get_preprocessing_transform(self, type):
        """Return the preprocessing pipeline for a dataset split.

        Every pipeline applies ``tio.ToCanonical``, crops or pads to
        128x128x128 voxels, and remaps label 4 to 3. For ``"train"`` and
        ``"valid"`` the label map is additionally one-hot encoded into 4
        classes. Any other value skips the one-hot step.

        Args:
            type: Split name. It shadows the builtin, but the name is kept so
                keyword callers keep working.

        Returns:
            A ``tio.Compose`` of the selected transforms.
        """
        transforms = [
            tio.ToCanonical(),
            tio.CropOrPad((128, 128, 128)),
            # Folds label 4 onto 3 so labels fit the 4 one-hot classes below.
            # Presumably this is the BraTS convention where label 3 is unused;
            # confirm.
            tio.RemapLabels({4: 3}),
        ]
        if type in ("train", "valid"):
            transforms.append(tio.OneHot(num_classes=4))
        return tio.Compose(transforms)

    def setup(self):
        """Randomly split ``self.subjects`` into train/val/test datasets.

        Requires ``prepare_data()`` to have been called first. The split is
        not seeded here, so it depends on torch's global RNG state.
        """
        train_subjects, val_subjects, test_subjects = (
            torch.utils.data.random_split(self.subjects, self.splits)
        )
        self.train_set = tio.SubjectsDataset(
            train_subjects, transform=self.get_preprocessing_transform("train")
        )
        self.val_set = tio.SubjectsDataset(
            val_subjects, transform=self.get_preprocessing_transform("valid")
        )
        # NOTE(review): the test set reuses the "train" pipeline (including
        # OneHot), so the non-one-hot branch of get_preprocessing_transform is
        # never used in this class. Confirm whether "test" was intended here.
        self.test_set = tio.SubjectsDataset(
            test_subjects, transform=self.get_preprocessing_transform("train")
        )