forked from marian42/butterflies
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rotation_dataset.py
49 lines (40 loc) · 1.95 KB
/
rotation_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import os
import numpy as np
from skimage import io, transform
import csv
import random
import math
from config import *
# Selects the image source used by RotationDataset.__getitem__: when True it
# loads 'data/images_alpha/<id>.png' and keeps only the first three channels;
# when False it loads 'data/images_128/<id>.jpg' unchanged.
USE_ALPHA_IMAGES = False
# A pixel counts as background ("white") only if all of its first three
# channels are at or above this value.
WHITE_THRESHOLD = 0.95


def clip_image(image):
    """Crop an image to its non-white content and pad it to a white square.

    Args:
        image: Tensor indexed as (channel, row, column) with at least three
            channels. Values are compared against ``WHITE_THRESHOLD``.

    Returns:
        A float32 tensor of shape ``(3, S, S)`` where ``S`` is the longer side
        of the cropped content. The content is centred and the remaining
        pixels are filled with 1. If the image contains no non-white pixel it
        is not cropped, only padded to a square.
    """
    non_white = (
        (image[0, :, :] < WHITE_THRESHOLD)
        | (image[1, :, :] < WHITE_THRESHOLD)
        | (image[2, :, :] < WHITE_THRESHOLD)
    )
    coords = non_white.nonzero()

    # An all-white image has no bounding box; min/max over an empty tensor
    # would raise, so leave the image uncropped in that case.
    if coords.shape[0] > 0:
        top, left = coords.min(dim=0).values.tolist()
        bottom, right = coords.max(dim=0).values.tolist()
        # Slice ends are exclusive, so add 1 to keep the last non-white
        # row and column inside the crop.
        image = image[:, top:bottom + 1, left:right + 1]

    height, width = image.shape[1], image.shape[2]
    new_size = max(height, width)
    result = torch.ones((3, new_size, new_size), dtype=torch.float32)
    y, x = (new_size - height) // 2, (new_size - width) // 2
    result[:, y:y + height, x:x + width] = image
    return result
class RotationDataset(Dataset):
    """Dataset of randomly rotated images labelled with the applied angle.

    Image ids and their stored angles (in degrees) are read from
    'data/rotations.csv', one ``id,angle`` pair per row. Each item is the
    image rotated by ``-stored_angle + random_angle`` together with the sine
    and cosine of ``random_angle``.
    """

    def __init__(self, return_hashes=False):
        """Load image ids and angles from 'data/rotations.csv'.

        Args:
            return_hashes: Accepted for interface compatibility; not used by
                this class.

        Raises:
            ValueError: If the rows do not have exactly two columns or an
                angle cannot be parsed as a float.
        """
        # newline='' is what the csv module expects for files it reads; the
        # with-block closes the handle, which was previously left open.
        with open('data/rotations.csv', 'r', newline='') as file:
            # Skip blank lines, which csv.reader yields as empty rows.
            rows = [row for row in csv.reader(file) if row]

        if rows:
            self.image_ids, angles = zip(*rows)
        else:
            # Empty file: produce an empty dataset instead of failing to unpack.
            self.image_ids, angles = (), ()
        self.angles = [float(angle) for angle in angles]

    def __len__(self):
        """Return the number of images listed in the CSV file."""
        return len(self.image_ids)

    def __getitem__(self, index):
        """Load, rotate, crop and resize the image at ``index``.

        Returns:
            A tuple ``(image, label)``. ``image`` is a float32 tensor pooled
            to ROTATION_NETWORK_RESOLUTION on both spatial axes. ``label`` is
            a tensor holding ``(sin, cos)`` of the random angle that was
            applied.
        """
        if USE_ALPHA_IMAGES:
            pattern = 'data/images_alpha/{:s}.png'
        else:
            pattern = 'data/images_128/{:s}.jpg'
        image = io.imread(pattern.format(self.image_ids[index]))
        if USE_ALPHA_IMAGES:
            # Keep only the first three channels of the PNG.
            image = image[:, :, :3]

        # Random jitter in [-5, 5) degrees plus a random multiple of 90 degrees.
        angle = random.random() * 10 - 5 + random.randint(0, 3) * 90

        # Subtract the stored angle and add the random one; the corners
        # exposed by the rotation are filled with 1.
        image = transform.rotate(
            image,
            -self.angles[index] + angle,
            resize=True,
            clip=True,
            mode='constant',
            cval=1,
        )
        # Reorder from (row, column, channel) to (channel, row, column).
        image = torch.tensor(image.transpose((2, 0, 1)), dtype=torch.float32)
        image = clip_image(image)
        image = F.adaptive_avg_pool2d(
            image, (ROTATION_NETWORK_RESOLUTION, ROTATION_NETWORK_RESOLUTION)
        )

        radians = math.radians(angle)
        return image, torch.tensor((math.sin(radians), math.cos(radians)))