-
Notifications
You must be signed in to change notification settings - Fork 1
/
data_util.py
91 lines (69 loc) · 2.61 KB
/
data_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import pickle
import random
import numpy as np
from torch.utils.data import Dataset
import augly.image as imaugs
import PIL.Image as Image
import torch
import torchvision
from transformers import AutoTokenizer
class street_dataset(Dataset):
    """Dataset that returns, per street, a stacked tensor of its images.

    Item ``i`` corresponds to entry ``i`` of the list unpickled from
    ``sv_list_path``; each entry is treated as a collection of image file
    names.  At most ``max_images`` images are loaded per street; when a
    street has more, a random subset of that size is drawn on every access.

    Args:
        sv_root_dir: Prefix prepended to every image file name.  It is joined
            by plain string concatenation, so it must already end with a
            path separator.
        sv_list_path: Pickle file holding the per-street image lists.
        pic2id_path: ``.npy`` file holding a pickled object that is indexed
            by image file name (presumably a dict mapping file name -> id).
        max_images: Upper bound on the number of images returned per street.
    """

    def __init__(
        self,
        sv_root_dir,
        sv_list_path='data/sv_list.pkl',
        pic2id_path='data/pic2id.npy',
        max_images=16,
    ):
        self.sv_root_dir = sv_root_dir
        self.max_images = max_images
        # SECURITY: both loads below unpickle data and can execute arbitrary
        # code -- only point them at trusted files.
        with open(sv_list_path, 'rb') as f:
            self.st_nid_list = pickle.load(f)
        # ``.item()`` unwraps the 0-d object array that np.save produces for
        # a single Python object.
        self.pic2id = np.load(pic2id_path, allow_pickle=True).item()
        self.transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize((224, 224)),
                torchvision.transforms.ToTensor(),
            ]
        )

    def __len__(self):
        """Return the number of streets."""
        return len(self.st_nid_list)

    def _select_images(self, street_idx):
        """Return the image file names to load for ``street_idx``.

        The full list is returned unchanged when it holds at most
        ``self.max_images`` entries; otherwise a random sample of exactly
        ``self.max_images`` distinct entries is drawn.
        """
        candidates = self.st_nid_list[street_idx]
        if len(candidates) > self.max_images:
            return random.sample(candidates, self.max_images)
        return candidates

    def __getitem__(self, street_idx):
        """Load the (possibly sub-sampled) images of one street.

        Returns:
            A tuple ``(streetview, length, street_idx, sv_id)`` where
            ``streetview`` is the transformed images stacked along a new
            leading dimension, ``length`` is the number of images stacked,
            ``street_idx`` is the index that was requested and ``sv_id`` is
            the list of ``pic2id`` values of the images, in the same order.

        Raises:
            KeyError: If an image file name is missing from ``pic2id``.
            RuntimeError: Raised by ``torch.stack`` when the street has no
                images.
        """
        streetview_list = []
        sv_id = []
        for img_file_name in self._select_images(street_idx):
            sv_id.append(self.pic2id[img_file_name])
            image = Image.open(self.sv_root_dir + img_file_name)
            # The transform is always set by __init__; the guard only matters
            # if a caller cleared it afterwards.
            if self.transform:
                image = self.transform(image)
            streetview_list.append(image)
        length = len(streetview_list)
        streetview = torch.stack(streetview_list, 0)
        return streetview, length, street_idx, sv_id
class region_dataset_test(Dataset):
    """Dataset that returns, per street, a stacked tensor of all its images.

    Item ``i`` corresponds to entry ``i`` of the list unpickled from
    ``sv_list_path``; each entry is treated as a collection of image file
    names, and every image in it is loaded (no sub-sampling).

    Args:
        sv_root_dir: Prefix prepended to every image file name.  It is joined
            by plain string concatenation, so it must already end with a
            path separator.
        sv_list_path: Pickle file holding the per-street image lists.
        pic2id_path: ``.npy`` file holding a pickled object that is indexed
            by image file name (presumably a dict mapping file name -> id).
    """

    def __init__(
        self,
        sv_root_dir,
        sv_list_path='data/sv_list.pkl',
        pic2id_path='data/pic2id.npy',
    ):
        self.sv_root_dir = sv_root_dir
        # SECURITY: both loads below unpickle data and can execute arbitrary
        # code -- only point them at trusted files.
        with open(sv_list_path, 'rb') as f:
            self.st_nid_list = pickle.load(f)
        # ``.item()`` unwraps the 0-d object array that np.save produces for
        # a single Python object.
        self.pic2id = np.load(pic2id_path, allow_pickle=True).item()
        self.transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize((224, 224)),
                torchvision.transforms.ToTensor(),
            ]
        )

    def __len__(self):
        """Return the number of streets."""
        return len(self.st_nid_list)

    def __getitem__(self, street_idx):
        """Load every image of one street.

        Returns:
            A tuple ``(streetview, street_idx, sv_id)`` where ``streetview``
            is the transformed images stacked along a new leading dimension,
            ``street_idx`` is the index that was requested and ``sv_id`` is
            the list of ``pic2id`` values of the images, in the same order.

        Raises:
            KeyError: If an image file name is missing from ``pic2id``.
            RuntimeError: Raised by ``torch.stack`` when the street has no
                images.
        """
        streetview_list = []
        sv_id = []
        for img_file_name in self.st_nid_list[street_idx]:
            sv_id.append(self.pic2id[img_file_name])
            image = Image.open(self.sv_root_dir + img_file_name)
            # The transform is always set by __init__; the guard only matters
            # if a caller cleared it afterwards.
            if self.transform:
                image = self.transform(image)
            streetview_list.append(image)
        streetview = torch.stack(streetview_list, 0)
        return streetview, street_idx, sv_id