forked from andy-yun/pytorch-0.4-yolov3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
98 lines (84 loc) · 3.56 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/python
# encoding: utf-8
import os
import random
import torch
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
from utils import read_truths_args, read_truths
from image import *
def custom_collate(batch):
data = torch.stack([item[0] for item in batch], 0)
targets = torch.stack([item[1] for item in batch], 0)
return data, targets
class listDataset(Dataset):
def __init__(self, root, shape=None, shuffle=True, crop=False, jitter=0.3, hue=0.1, saturation=1.5, exposure=1.5, transform=None, target_transform=None, train=False, seen=0, batch_size=64, num_workers=4):
with open(root, 'r') as file:
self.lines = file.readlines()
if shuffle:
random.shuffle(self.lines)
self.nSamples = len(self.lines)
self.transform = transform
self.target_transform = target_transform
self.train = train
self.shape = shape
self.seen = seen
self.batch_size = batch_size
self.num_workers = num_workers
self.crop = crop
self.jitter = jitter
self.hue = hue
self.saturation = saturation
self.exposure = exposure
def __len__(self):
return self.nSamples
def get_different_scale(self):
if self.seen < 4000*self.batch_size:
wh = 13*32 # 416
elif self.seen < 8000*self.batch_size:
wh = (random.randint(0,3) + 13)*32 # 416, 480
elif self.seen < 12000*self.batch_size:
wh = (random.randint(0,5) + 12)*32 # 384, ..., 544
elif self.seen < 16000*self.batch_size:
wh = (random.randint(0,7) + 11)*32 # 352, ..., 576
else: # self.seen < 20000*self.batch_size:
wh = (random.randint(0,9) + 10)*32 # 320, ..., 608
return (wh, wh)
def __getitem__(self, index):
assert index <= len(self), 'index range error'
imgpath = self.lines[index].rstrip()
if self.train:
if self.seen % (self.batch_size * 10) == 0: # in paper, every 10 batches, but we did every 64 images
self.shape = self.get_different_scale()
img, label = load_data_detection(imgpath, self.shape, self.crop, self.jitter, self.hue, self.saturation, self.exposure)
label = torch.from_numpy(label)
else:
img = Image.open(imgpath).convert('RGB')
if self.shape:
img, org_w, org_h = letterbox_image(img, self.shape[0], self.shape[1]), img.width, img.height
labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
label = torch.zeros(50*5)
#if os.path.getsize(labpath):
#tmp = torch.from_numpy(np.loadtxt(labpath))
try:
tmp = torch.from_numpy(read_truths_args(labpath, 8.0/img.width).astype('float32'))
except Exception:
tmp = torch.zeros(1,5)
#tmp = torch.from_numpy(read_truths(labpath))
tmp = tmp.view(-1)
tsz = tmp.numel()
#print('labpath = %s , tsz = %d' % (labpath, tsz))
if tsz > 50*5:
label = tmp[0:50*5]
elif tsz > 0:
label[0:tsz] = tmp
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
label = self.target_transform(label)
self.seen = self.seen + self.num_workers
if self.train:
return (img, label)
else:
return (img, label, org_w, org_h)