-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMesh_dataset.py
142 lines (117 loc) · 5.88 KB
/
Mesh_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from torch.utils.data import Dataset
import pandas as pd
import torch
import numpy as np
from vedo import *
from scipy.spatial import distance_matrix
class Mesh_Dataset(Dataset):
    """Dataset that samples fixed-size patches of triangle cells from labelled meshes.

    Each row of the list file names one mesh file readable by ``load`` (from vedo)
    that carries an integer per-cell ``'Label'`` array. Label 0 is treated as
    gingiva and labels > 0 as tooth when choosing which cells enter a patch.
    """

    # Distance thresholds, in min-max normalised barycentre space, used to build
    # the two adjacency matrices returned by ``__getitem__`` ('A_S' and 'A_L').
    A_S_RADIUS = 0.1
    A_L_RADIUS = 0.2

    def __init__(self, data_list_path, num_classes=15, patch_size=7000):
        """
        Args:
            data_list_path (str): Path to a header-less CSV/text file whose first
                column holds one mesh file path per row.
            num_classes (int): Number of segmentation classes. Stored on the
                instance; not used by this class's own methods.
            patch_size (int): Number of cells sampled from each mesh.
        """
        self.data_list = pd.read_csv(data_list_path, header=None)
        self.num_classes = num_classes
        self.patch_size = patch_size

    def __len__(self):
        """Return the number of meshes listed in the data list file."""
        return self.data_list.shape[0]

    def __getitem__(self, idx):
        """Load mesh ``idx``, build per-cell features and sample one patch.

        Returns:
            dict with
              'cells'  -- float32 tensor (15, patch_size). For each sampled cell it
                          holds 9 centred and standardised vertex coordinates,
                          3 min-max scaled barycentre coordinates and
                          3 standardised unit normals.
              'labels' -- int32 tensor (1, patch_size) of cell labels.
              'A_S', 'A_L' -- float64 tensors (patch_size, patch_size). Each is a
                          row-normalised adjacency of cells whose barycentres lie
                          closer than A_S_RADIUS / A_L_RADIUS respectively.

        Raises:
            ValueError: if the mesh has too few cells to fill a patch.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        mesh = load(self.data_list.iloc[idx, 0])  # column 0 holds the mesh file name
        labels = mesh.getCellArray('Label').astype('int32').reshape(-1, 1)

        features = self._cell_features(mesh)
        selected_idx = self._sample_patch_indices(labels, self.patch_size)
        X_train = features[selected_idx, :].astype(np.float32)
        Y_train = labels[selected_idx, :].astype(np.int32)

        # Adjacency is defined on the normalised barycentres (columns 9:12).
        distances = self._pairwise_distances(X_train[:, 9:12])
        S1 = self._row_normalized_adjacency(distances, self.A_S_RADIUS)
        S2 = self._row_normalized_adjacency(distances, self.A_L_RADIUS)

        return {
            'cells': torch.from_numpy(X_train.transpose(1, 0)),
            'labels': torch.from_numpy(Y_train.transpose(1, 0)),
            'A_S': torch.from_numpy(S1),
            'A_L': torch.from_numpy(S2),
        }

    def _cell_features(self, mesh):
        """Return an (N, 15) feature array (vertices, barycentres, normals) for ``mesh``."""
        num_cells = mesh.NCells()
        polydata = mesh.polydata()
        points = vtk2numpy(polydata.GetPoints().GetData())
        # Connectivity is stored as [count, i, j, k] per cell. Dropping the count
        # and reshaping to (N, 9) assumes every cell is a triangle.
        ids = vtk2numpy(polydata.GetPolys().GetData()).reshape((num_cells, -1))[:, 1:]
        cells = points[ids].reshape(num_cells, 9).astype(dtype='float32')

        # Move the mesh to the origin.
        center = np.asarray(mesh.centerOfMass())[0:3]
        cells -= np.tile(center, 3)

        # Custom normal calculation: the vtk/vedo built-in filter changes the
        # number of points.
        normals = self._cell_normals(cells)

        points = mesh.points().copy()
        points -= center
        barycenters = mesh.cellCenters()
        barycenters -= center

        means, stds = points.mean(axis=0), points.std(axis=0)
        mins, maxs = points.min(axis=0), points.max(axis=0)
        nmeans, nstds = normals.mean(axis=0), normals.std(axis=0)

        # Standardise vertex coordinates, min-max scale barycentres and
        # standardise normals. Zero spreads are guarded so flat axes yield 0, not NaN.
        cells = (cells - np.tile(means, 3)) / np.tile(self._nonzero(stds), 3)
        barycenters = (barycenters - mins) / self._nonzero(maxs - mins)
        normals = (normals - nmeans) / self._nonzero(nstds)
        return np.column_stack((cells, barycenters, normals))

    @staticmethod
    def _nonzero(scale):
        """Return ``scale`` with zero entries replaced by 1 so dividing by it is safe."""
        return np.where(scale == 0, 1, scale)

    @staticmethod
    def _cell_normals(cells):
        """Return float32 unit normals (N, 3) for triangles given as (N, 9) vertex rows.

        The normal is cross(p1 - p2, p2 - p3), normalised to unit length.
        Degenerate (zero-area) triangles get a zero vector. A NaN here would
        otherwise poison the mean/std used to standardise every normal.
        """
        cells = np.asarray(cells, dtype=np.float32)
        normals = np.cross(cells[:, 0:3] - cells[:, 3:6], cells[:, 3:6] - cells[:, 6:9])
        lengths = np.linalg.norm(normals, axis=1, keepdims=True)
        return normals / np.where(lengths == 0, 1, lengths)

    @staticmethod
    def _sample_patch_indices(labels, patch_size):
        """Return ``patch_size`` sorted cell indices, preferring tooth cells.

        If there are more tooth cells (label > 0) than ``patch_size``, a random
        subset of them is taken. Otherwise all tooth cells are kept and the
        remainder is filled with random gingiva cells (label == 0).

        Raises:
            ValueError: if there are not enough label >= 0 cells to fill the patch.
        """
        labels = np.asarray(labels).reshape(-1)
        positive_idx = np.flatnonzero(labels > 0)   # tooth cells
        negative_idx = np.flatnonzero(labels == 0)  # gingiva cells
        num_positive = len(positive_idx)
        if num_positive > patch_size:
            selected_idx = np.random.choice(positive_idx, size=patch_size, replace=False)
        else:
            num_negative = patch_size - num_positive
            if num_negative > len(negative_idx):
                raise ValueError(
                    f'mesh has only {num_positive + len(negative_idx)} cells with '
                    f'label >= 0; cannot sample a patch of {patch_size} cells')
            # This only shuffles the positives, which are sorted below anyway. The
            # draw is kept so the random stream, and hence seeded patches, stays stable.
            positive_selected_idx = np.random.choice(positive_idx, size=num_positive, replace=False)
            negative_selected_idx = np.random.choice(negative_idx, size=num_negative, replace=False)
            selected_idx = np.concatenate((positive_selected_idx, negative_selected_idx))
        return np.sort(selected_idx, axis=None)

    @staticmethod
    def _pairwise_distances(points):
        """Return the (M, M) Euclidean distance matrix of ``points`` (M, 3) as NumPy."""
        if torch.cuda.is_available():
            # NOTE(review): initialising CUDA inside DataLoader worker processes can
            # fail under the 'fork' start method. Confirm how num_workers is set.
            tensor = torch.as_tensor(points, device='cuda')
            return torch.cdist(tensor, tensor).cpu().numpy()
        return distance_matrix(points, points)

    @staticmethod
    def _row_normalized_adjacency(distances, radius):
        """Return a float64 matrix whose row i holds 1/k on each of its k neighbours.

        A neighbour is any entry with distance < ``radius``. The diagonal
        (distance 0) counts, so finite rows never sum to zero. Row sums are
        broadcast rather than expanded into a second (M, M) matrix.
        """
        adjacency = (np.asarray(distances) < radius).astype(np.float32)
        return adjacency / adjacency.sum(axis=1, keepdims=True, dtype=np.float64)