-
Notifications
You must be signed in to change notification settings - Fork 38
/
visualize_dataset.py
65 lines (56 loc) · 1.94 KB
/
visualize_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from re import L
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import lmdb
import six
import sys
from PIL import Image
import math
import torchvision
class lmdbDataset(Dataset):
    """Read-only map-style dataset that yields images stored in an LMDB.

    The database is expected to contain a ``num-samples`` entry (the sample
    count as an ASCII integer) and 1-based keys ``image-%09d`` holding
    encoded image bytes. Samples are returned as the output of
    ``transforms.ToTensor()`` applied to the RGB-converted image; label
    entries are not read.
    """

    def __init__(self, root=None):
        """Open the LMDB at ``root`` read-only and read its sample count.

        Raises:
            lmdb.Error: if the environment cannot be opened (raised by
                ``lmdb.open`` itself, which never returns a falsy handle).
            ValueError: if the database has no ``num-samples`` entry.
        """
        self.env = lmdb.open(
            root,
            max_readers=16,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False)
        with self.env.begin(write=False) as txn:
            raw_count = txn.get('num-samples'.encode())
        if raw_count is None:
            raise ValueError("lmdb at %s has no 'num-samples' entry" % (root,))
        self.nSamples = int(raw_count)
        self.transform = transforms.ToTensor()

    def __len__(self):
        """Return the number of samples declared by ``num-samples``."""
        return self.nSamples

    def __getitem__(self, index):
        """Return the image tensor for 0-based ``index``.

        If the requested entry is missing or cannot be decoded, the following
        samples are tried in order (wrapping around), so a bad entry is
        skipped instead of aborting iteration.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
            RuntimeError: if no sample in the database can be decoded.
        """
        # Stored keys run 1..nSamples, so the valid 0-based range is
        # [0, nSamples); index == nSamples would address a non-existent key.
        if not 0 <= index < self.nSamples:
            raise IndexError('index range error')
        # Visit each sample at most once so a fully corrupt database cannot
        # recurse or loop forever.
        for offset in range(self.nSamples):
            key_index = (index + offset) % self.nSamples + 1  # keys are 1-based
            img = self._load_image(key_index)
            if img is not None:
                return img
        raise RuntimeError('no decodable image found in lmdb')

    def _load_image(self, key_index):
        """Decode the image stored under 1-based ``key_index``; None if unusable."""
        with self.env.begin(write=False) as txn:
            imgbuf = txn.get(('image-%09d' % key_index).encode())
        if imgbuf is None:
            print('Missing image for %d' % key_index)
            return None
        try:
            img = Image.open(six.BytesIO(imgbuf)).convert('RGB')
            return self.transform(img)
        except OSError:  # IOError is an alias of OSError on Python 3
            print('Corrupted image for %d' % key_index)
            return None
if __name__ == "__main__":
    # Quick visual sanity check: write up to three randomly chosen samples
    # from the ReCTS LMDB to ReCTS0.jpg, ReCTS1.jpg and ReCTS2.jpg.
    root_dir = './data_CVPR2021/training/label/real/11.ReCTS'
    samples = DataLoader(lmdbDataset(root_dir), batch_size=1, shuffle=True)
    for step, image_batch in enumerate(samples):
        out_name = 'ReCTS%d.jpg' % step
        torchvision.utils.save_image(image_batch, out_name)
        if step >= 2:
            break