Skip to content

Commit

Permalink
add dhcp dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
helenass97 committed Jul 18, 2022
1 parent 12ac0cc commit 0f4ac69
Show file tree
Hide file tree
Showing 5 changed files with 2,851 additions and 7 deletions.
131 changes: 131 additions & 0 deletions dataloader_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
import torchvision.transforms as transforms
from skimage.transform import resize
from biobank_dataloader import *
from dhcp_dataloader import * # using DHCP_2D right now
import torchvision
import SimpleITK as sitk
import random
import numpy as np
from PIL import Image


Expand Down Expand Up @@ -396,3 +398,132 @@ def init_biobank_age_dataloader(opt, shuffle_test=False):

return healthy_dataloader_train, healthy_dataloader_val, healthy_dataloader_test, \
anomaly_dataloader_train, anomaly_dataloader_val, anomaly_dataloader_test


def init_dhcp_dataloader_2d(opt, shuffle_test=False):
'''
Initialize both datasets and dataloaders
image_size = [128, 160]
'''
if (not opt.aug_rician_noise == None) or (not opt.aug_bspline_deformation == None) or (not opt.resize_image == None):
transforms = []
else:
transforms = None

if opt.resize_image:
transforms.append(ResizeImage(image_size=opt.resize_size))

if opt.aug_rician_noise:
transforms.append(RicianNoise(noise_level=opt.aug_rician_noise))

if opt.aug_bspline_deformation:
transforms.append(ElasticDeformationsBspline(num_controlpoints=opt.aug_bspline_deformation[0], sigma=opt.aug_bspline_deformation[1]))

if opt.aug_rician_noise or opt.aug_bspline_deformation or opt.resize_image:
transforms = torchvision.transforms.Compose(transforms)

healthy_train = DHCP_2D(image_path=opt.dataroot,
label_path=opt.label_path,
num_classes=2,
task='regression',
class_label=0,
transform=transforms)

anomaly_train = DHCP_2D(image_path=opt.dataroot,
label_path=opt.label_path,
num_classes=2,
task='regression',
class_label=1,
transform=transforms)

healthy_dataloader_train, healthy_dataloader_val, healthy_dataloader_test = train_val_test_split(healthy_train, val_split=0.1, test_split=0.1,
random_seed=opt.random_seed)
anomaly_dataloader_train, anomaly_dataloader_val, anomaly_dataloader_test = train_val_test_split(anomaly_train, val_split=0.1, test_split=0.1,
random_seed=opt.random_seed)


print('Train healthy data length: ', len(healthy_dataloader_train), 'Val data length: ',len(healthy_dataloader_val), 'Test data length: ', len(healthy_dataloader_test))
print('Train anomaly data length: ', len(anomaly_dataloader_train), 'Val data length: ',len(anomaly_dataloader_val), 'Test data length: ', len(anomaly_dataloader_test))

healthy_dataloader_train = torch.utils.data.DataLoader(healthy_dataloader_train, batch_size=opt.batch_size//2,
shuffle=True)
anomaly_dataloader_train = torch.utils.data.DataLoader(anomaly_dataloader_train, batch_size=opt.batch_size//2,
shuffle=True)

healthy_dataloader_val = torch.utils.data.DataLoader(healthy_dataloader_val, batch_size=opt.batch_size//2,
shuffle=True)
anomaly_dataloader_val = torch.utils.data.DataLoader(anomaly_dataloader_val, batch_size=opt.batch_size//2,
shuffle=True)
healthy_dataloader_test = torch.utils.data.DataLoader(healthy_dataloader_test, batch_size=opt.batch_size//2,
shuffle=shuffle_test)
anomaly_dataloader_test = torch.utils.data.DataLoader(anomaly_dataloader_test, batch_size=opt.batch_size//2,
shuffle=shuffle_test)

return healthy_dataloader_train, healthy_dataloader_val, healthy_dataloader_test, anomaly_dataloader_train, anomaly_dataloader_val, anomaly_dataloader_test



# def init_dhcp_dataloader_2d_reg(opt, shuffle_test=False):
# '''
# Initialize both datasets and dataloaders
# image_size = [128, 160]
# '''
# if (not opt.aug_rician_noise == None) or (not opt.aug_bspline_deformation == None) or (not opt.resize_image == None):
# transforms = []
# else:
# transforms = None

# if opt.resize_image:
# transforms.append(ResizeImage(image_size=opt.resize_size))

# if opt.aug_rician_noise:
# transforms.append(RicianNoise(noise_level=opt.aug_rician_noise))

# if opt.aug_bspline_deformation:
# transforms.append(ElasticDeformationsBspline(num_controlpoints=opt.aug_bspline_deformation[0], sigma=opt.aug_bspline_deformation[1]))

# if opt.aug_rician_noise or opt.aug_bspline_deformation or opt.resize_image:
# transforms = torchvision.transforms.Compose(transforms)

# healthy_train = DHCP_2D(image_path=opt.dataroot,
# label_path=opt.labels, #was label_path before
# num_classes=2,
# task='regression',
# class_label=0,
# #get_id=False,
# transform=transforms)

# anomaly_train = DHCP_2D(image_path=opt.dataroot,
# label_path=opt.labels, #was label_path before
# num_classes=2,
# task='regression',
# class_label=1,
# #get_id = False,
# transform=transforms)

# healthy_dataloader_train, healthy_dataloader_val, healthy_dataloader_test = train_val_test_split(healthy_train, val_split=0.1, test_split=0.1,
# random_seed=opt.random_seed)
# anomaly_dataloader_train, anomaly_dataloader_val, anomaly_dataloader_test = train_val_test_split(anomaly_train, val_split=0.1, test_split=0.1,
# random_seed=opt.random_seed)


# print('Train term data length: ', len(healthy_dataloader_train), 'Val data length: ',len(healthy_dataloader_val), 'Test data length: ', len(healthy_dataloader_test))
# print('Train preterm data length: ', len(anomaly_dataloader_train), 'Val data length: ',len(anomaly_dataloader_val), 'Test data length: ', len(anomaly_dataloader_test))

# healthy_dataloader_train = torch.utils.data.DataLoader(healthy_dataloader_train, batch_size=opt.batch_size//2,
# shuffle=True)
# anomaly_dataloader_train = torch.utils.data.DataLoader(anomaly_dataloader_train, batch_size=opt.batch_size//2,
# shuffle=True)

# healthy_dataloader_val = torch.utils.data.DataLoader(healthy_dataloader_val, batch_size=opt.batch_size//2,
# shuffle=True)
# anomaly_dataloader_val = torch.utils.data.DataLoader(anomaly_dataloader_val, batch_size=opt.batch_size//2,
# shuffle=True)
# healthy_dataloader_test = torch.utils.data.DataLoader(healthy_dataloader_test, batch_size=opt.batch_size//2,
# shuffle=shuffle_test)
# anomaly_dataloader_test = torch.utils.data.DataLoader(anomaly_dataloader_test, batch_size=opt.batch_size//2,
# shuffle=shuffle_test)

# return healthy_dataloader_train, healthy_dataloader_val, healthy_dataloader_test, anomaly_dataloader_train, anomaly_dataloader_val, anomaly_dataloader_test


Loading

0 comments on commit 0f4ac69

Please sign in to comment.