-
Notifications
You must be signed in to change notification settings - Fork 81
/
data_reader.py
executable file
·53 lines (40 loc) · 1.48 KB
/
data_reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os.path
import numpy as np
import os,sys,copy,time,cv2,pickle,gzip
from scipy.signal import convolve2d
# Absolute, symlink-resolved directory containing this module
# (the head element of os.path.split on the resolved file path).
code_dir = os.path.split(os.path.realpath(__file__))[0]
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torchvision
from PIL import Image
from Utils import *
class DataReader:
    """Load one frame (RGB, depth, segmentation, metadata and optional maps) from disk.

    Companion files are located by string substitution on the colour-image
    path; see ``read_data_by_colorfile``.
    """

    # Arrays that are zeroed wherever the segmentation marks an environment body.
    _MASKED_KEYS = ('rgb', 'depth', 'seg', 'nocs_map', 'xyz_map', 'normal_map')

    def __init__(self, cfg):
        """Store the reader configuration.

        Args:
            cfg: Mapping of settings; ``read_data_by_colorfile`` reads ``cfg['zfar']``.
        """
        self.cfg = cfg

    @staticmethod
    def _read_pil_array(path):
        """Decode the image at ``path`` with PIL into a NumPy array, closing the file afterwards."""
        with Image.open(path) as img:
            # np.array copies the pixel data, so the array outlives the closed file.
            return np.array(img)

    @staticmethod
    def _read_cv2_unchanged(path):
        """Read ``path`` with OpenCV without conversion; raise if it cannot be read."""
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        return img

    def read_data_by_colorfile(self, color_file, fetch=('nocs_map', 'xyz_map', 'normal_map')):
        """Load a frame and its companion files, blanking out environment bodies.

        Companion paths are derived with ``str.replace`` on the whole path, so
        *every* occurrence is substituted: ``'rgb'`` -> ``'depth'``, ``'seg'``,
        ``'nunocs'``, ``'normal'``; and ``'rgb.png'`` -> ``'meta.pkl'``.

        Args:
            color_file: Path to the colour image.
            fetch: Optional extra outputs to load; any of ``'nocs_map'``,
                ``'xyz_map'``, ``'normal_map'``. Defaults to all three.

        Returns:
            dict with ``'rgb'``, ``'seg'``, ``'depth'``, ``'color_file'`` plus the
            requested optional maps. Every array is set to 0 at pixels whose
            segmentation id appears in ``meta['env_body_ids']``.

        Raises:
            FileNotFoundError: if the depth or segmentation image cannot be read.
        """
        rgb = self._read_pil_array(color_file)

        # Stored depth is divided by 1e4 (presumably metres * 1e4 on disk — TODO confirm).
        depth = self._read_cv2_unchanged(color_file.replace('rgb', 'depth')) / 1e4
        # Values below 0.1 or beyond cfg['zfar'] are treated as invalid and set to 0.
        depth[depth < 0.1] = 0
        depth[depth > self.cfg['zfar']] = 0

        seg = self._read_cv2_unchanged(color_file.replace('rgb', 'seg')).astype(int)

        # NOTE(review): pickle.load can execute arbitrary code; only load trusted datasets.
        with open(color_file.replace('rgb.png', 'meta.pkl'), 'rb') as ff:
            meta = pickle.load(ff)
        K = meta['K']
        # list(...) so np.isin below sees a sequence (it does not expand Python sets).
        env_body_ids = list(meta['env_body_ids'])

        data = {'rgb': rgb, 'seg': seg, 'depth': depth, 'color_file': color_file}
        if 'nocs_map' in fetch:
            data['nocs_map'] = self._read_pil_array(color_file.replace('rgb', 'nunocs'))
        if 'xyz_map' in fetch:
            data['xyz_map'] = depth2xyzmap(depth, K)
        if 'normal_map' in fetch:
            data['normal_map'] = read_normal_image(color_file.replace('rgb', 'normal'))

        if env_body_ids:
            # One combined mask, built before `seg` itself is zeroed below.
            mask = np.isin(seg, env_body_ids)
            for key in self._MASKED_KEYS:
                if key in data:
                    data[key][mask] = 0
        return data