-
Notifications
You must be signed in to change notification settings - Fork 11
/
dataloader.py
83 lines (64 loc) · 2 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import numpy as np
import skimage
import torch
from PIL import Image
import torchvision
import torch.nn.functional as F
from io import BytesIO
def demosaic_raw(meas, noise_std=0.003, rng=None):
    """Turn a raw Bayer measurement tensor into a 4-plane, mean-centered tensor.

    The first channel of ``meas`` is scaled by 1/65535, perturbed with additive
    Gaussian noise, split into its four 2x2 Bayer sub-planes, warped by a small
    fixed rotation, cropped, and row/column mean-centered.

    Args:
        meas: Tensor such that ``meas.numpy()[0, :, :]`` is a 1024x1280 raw
            frame (presumably the (C, H, W) output of ``ToTensor`` -- confirm).
        noise_std: Standard deviation of the Gaussian noise added after
            scaling. Defaults to 0.003 (the previously hard-coded value);
            pass 0 to disable the noise.
        rng: Optional ``numpy.random.Generator`` to draw the noise from. When
            None, the global NumPy random state is used exactly as before.

    Returns:
        float32 tensor of shape (4, 500, 620); planes are ordered
        (b, gb, gr, r) following the original channel labels.

    Raises:
        ValueError: if the raw frame is not 1024x1280.
    """
    # Geometry is fixed for this sensor/dataset (assumed calibration -- confirm).
    raw_shape = (1024, 1280)
    # Rotation is in radians (~0.1 degrees).
    tform = skimage.transform.SimilarityTransform(rotation=0.00174)

    raw = meas.numpy()[0, :, :]
    if raw.shape != raw_shape:
        # Any other shape would fail on the sub-plane assignments below anyway;
        # fail early with a readable message instead of a broadcast error.
        raise ValueError(
            f"demosaic_raw expects a {raw_shape[0]}x{raw_shape[1]} raw frame, "
            f"got {raw.shape[0]}x{raw.shape[1]}"
        )

    X = raw / 65535.0
    if noise_std:
        if rng is None:
            noise = np.random.randn(X.shape[0], X.shape[1])
        else:
            noise = rng.standard_normal(X.shape)
        X = X + noise_std * noise

    # Split the mosaic into its four 2x2 sub-planes (labels from the original code).
    planes = np.zeros((raw_shape[0] // 2, raw_shape[1] // 2, 4))
    planes[:, :, 0] = X[0::2, 0::2]  # b
    planes[:, :, 1] = X[0::2, 1::2]  # gb
    planes[:, :, 2] = X[1::2, 0::2]  # gr
    planes[:, :, 3] = X[1::2, 1::2]  # r
    planes = skimage.transform.warp(planes, tform)

    # Fixed border crop to 500x620 (presumably removes warp/edge artefacts).
    im = planes[6:506, 10:630, :]

    row_means = im.mean(axis=1, keepdims=True)  # (500, 1, 4)
    col_means = im.mean(axis=0, keepdims=True)  # (1, 620, 4)
    # NOTE(review): all_mean is a single scalar over ALL channels, so each
    # channel ends up with mean (all_mean - channel_mean) instead of zero.
    # Kept as-is because downstream models presumably rely on it; confirm
    # whether per-channel im.mean(axis=(0, 1), keepdims=True) was intended.
    all_mean = row_means.mean()
    im = (im - row_means - col_means + all_mean).astype('float32')

    # (H, W, C) -> (C, H, W), same view the original double swapaxes produced.
    return torch.from_numpy(im.transpose(2, 0, 1))
class DatasetFromFilenames:
    """Map-style dataset pairing raw measurements with ground-truth images.

    Item ``i`` is ``(meas, im)``:

    - ``meas`` is the result of ``demosaic_raw`` applied to the measurement file.
    - ``im`` is the matching ground-truth image. It is converted to RGB,
      resized to 256x256 and turned into a tensor by
      ``torchvision.transforms.ToTensor``.

    Args:
        filenames_loc_meas: Text file listing measurement paths, one per line
            (resolved through ``get_paths``).
        filenames_loc_orig: Text file listing ground-truth image paths in the
            same order as the measurements.
    """

    def __init__(self, filenames_loc_meas, filenames_loc_orig):
        self.filenames_meas = filenames_loc_meas
        self.paths_meas = get_paths(self.filenames_meas)
        self.filenames_orig = filenames_loc_orig
        self.paths_orig = get_paths(self.filenames_orig)
        # The dataset length follows the measurement list.
        # NOTE(review): paths_orig is assumed to be at least as long; not checked.
        self.num_im = len(self.paths_meas)
        self.totensor = torchvision.transforms.ToTensor()
        self.resize = torchvision.transforms.Resize((256, 256))

    def __len__(self):
        return len(self.paths_meas)

    def __getitem__(self, index):
        """Load the (measurement, ground truth) pair at ``index``.

        ``index`` is wrapped modulo the number of measurements.
        """
        im_path = self.paths_orig[index % self.num_im]
        meas_path = self.paths_meas[index % self.num_im]

        # Ground truth: force 3-channel RGB, resize, convert to tensor.
        # The ``with`` block closes the file deterministically rather than
        # leaving it open until garbage collection.
        with Image.open(im_path) as img:
            im = self.totensor(self.resize(img.convert('RGB')))

        # Measurement: converted in its native mode, then demosaiced.
        # NOTE(review): demosaic_raw divides by 65535, which presumes ToTensor
        # left 16-bit values unscaled; torchvision maps PIL mode 'I;16' to
        # int16, which would wrap values above 32767 -- confirm the file mode.
        with Image.open(meas_path) as raw:
            meas = self.totensor(raw)
        meas = demosaic_raw(meas)

        return meas, im
def get_paths(fname, root='/media/data/salman/'):
    """Read a list file and return one prefixed path per non-blank line.

    Each line is stripped of surrounding whitespace and prefixed with ``root``
    by plain string concatenation, so ``root`` should end with a separator.
    Blank lines are skipped; they would otherwise yield the bare ``root``
    directory, which cannot be opened as an image.

    Args:
        fname: Path of the text file listing relative paths.
        root: Prefix prepended to every entry. Defaults to the original
            hard-coded data directory.

    Returns:
        List of path strings, in file order.
    """
    with open(fname, 'r') as f:
        return [root + entry for entry in (line.strip() for line in f) if entry]