-
Notifications
You must be signed in to change notification settings - Fork 5
/
dataset.py
93 lines (81 loc) · 3.31 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import cv2
import os
import torch
import numpy as np
import random
from torch.utils.data import Dataset
# Restrict OpenCV to a single internal thread for its own operations.
cv2.setNumThreads(1)
# Default torch device for this module: CUDA if available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class VimeoDataset(Dataset):
    """Triplet dataset read from a Vimeo-style directory layout.

    Each item stacks the frames im1.png (img0), im3.png (img1) and
    im2.png (gt) of one triplet into one channel-first tensor, in the
    order [img0, img1, gt] along dim 0.
    """

    # Side length of the square random crop applied to training samples.
    crop_size = 256

    def __init__(self, dataset_name, path, batch_size=32, model="RIFE"):
        """Read the train/test split lists and select one of them.

        Args:
            dataset_name: 'test' selects tri_testlist.txt, any other value
                selects tri_trainlist.txt. Names containing 'train' also
                enable random augmentation in __getitem__.
            path: root directory containing tri_trainlist.txt,
                tri_testlist.txt and a 'sequences' subdirectory.
            batch_size: stored on the instance; not used by this class.
            model: stored on the instance; not used by this class.

        Raises:
            OSError: if either list file cannot be opened.
        """
        self.batch_size = batch_size
        self.dataset_name = dataset_name
        self.model = model
        # Stored frame height/width; not referenced by the methods below.
        self.h = 256
        self.w = 448
        self.data_root = path
        self.image_root = os.path.join(self.data_root, 'sequences')
        train_fn = os.path.join(self.data_root, 'tri_trainlist.txt')
        test_fn = os.path.join(self.data_root, 'tri_testlist.txt')
        self.trainlist = self._read_list(train_fn)
        self.testlist = self._read_list(test_fn)
        self.load_data()

    @staticmethod
    def _read_list(filename):
        """Return the whitespace-stripped, non-blank lines of *filename*."""
        # A blank entry (e.g. a trailing empty line) would resolve to the
        # 'sequences' directory itself and make every imread return None.
        with open(filename, 'r') as f:
            lines = (line.strip() for line in f.read().splitlines())
            return [line for line in lines if line]

    def __len__(self):
        """Number of triplets in the selected split."""
        return len(self.meta_data)

    def load_data(self):
        """Point meta_data at the test list for 'test', else the train list."""
        # NOTE(review): any name other than 'test' loads the train list, but
        # only names containing 'train' are augmented in __getitem__.
        if self.dataset_name != 'test':
            self.meta_data = self.trainlist
        else:
            self.meta_data = self.testlist

    def aug(self, img0, gt, img1, h, w):
        """Crop the same random h x w window out of all three frames.

        The window position comes from img0's shape; gt and img1 are
        assumed to have the same height/width (not checked).

        Raises:
            ValueError: if the crop is larger than the frame.
        """
        ih, iw, _ = img0.shape
        if h > ih or w > iw:
            raise ValueError(f"crop {h}x{w} exceeds frame size {ih}x{iw}")
        x = np.random.randint(0, ih - h + 1)
        y = np.random.randint(0, iw - w + 1)
        window = (slice(x, x + h), slice(y, y + w))
        return img0[window], gt[window], img1[window]

    def getimg(self, index):
        """Load (img0, gt, img1) from im1.png, im2.png, im3.png of entry *index*.

        Raises:
            FileNotFoundError: if a frame cannot be read; cv2.imread signals
                failure by returning None rather than raising.
        """
        imgpath = os.path.join(self.image_root, self.meta_data[index])
        frames = []
        for name in ('im1.png', 'im2.png', 'im3.png'):
            filename = os.path.join(imgpath, name)
            img = cv2.imread(filename)
            if img is None:
                raise FileNotFoundError(f"could not read image: {filename}")
            frames.append(img)
        img0, gt, img1 = frames
        return img0, gt, img1

    def _augment(self, img0, gt, img1):
        """Apply random crop, channel/temporal/spatial flips and rotation."""
        img0, gt, img1 = self.aug(img0, gt, img1, self.crop_size, self.crop_size)
        if random.uniform(0, 1) < 0.5:
            # Reverse the channel order of every frame.
            img0, gt, img1 = img0[:, :, ::-1], gt[:, :, ::-1], img1[:, :, ::-1]
        if random.uniform(0, 1) < 0.5:
            # Swap the two input frames; gt stays in place.
            img0, img1 = img1, img0
        if random.uniform(0, 1) < 0.5:
            # Flip rows (vertical flip).
            img0, gt, img1 = img0[::-1], gt[::-1], img1[::-1]
        if random.uniform(0, 1) < 0.5:
            # Flip columns (horizontal flip).
            img0, gt, img1 = img0[:, ::-1], gt[:, ::-1], img1[:, ::-1]
        p = random.uniform(0, 1)
        if p < 0.25:
            rotation = cv2.ROTATE_90_CLOCKWISE
        elif p < 0.5:
            rotation = cv2.ROTATE_180
        elif p < 0.75:
            rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
        else:
            rotation = None
        if rotation is not None:
            img0, gt, img1 = [cv2.rotate(img, rotation) for img in (img0, gt, img1)]
        return img0, gt, img1

    def __getitem__(self, index):
        """Return the triplet at *index* as one tensor ordered [img0, img1, gt]."""
        img0, gt, img1 = self.getimg(index)
        if 'train' in self.dataset_name:
            img0, gt, img1 = self._augment(img0, gt, img1)
        # .copy() yields contiguous arrays: the flips above create
        # negative-stride views, which torch.from_numpy rejects.
        tensors = [torch.from_numpy(img.copy()).permute(2, 0, 1) for img in (img0, img1, gt)]
        return torch.cat(tensors, 0)