diff --git a/src/deep_neurographs/geometry.py b/src/deep_neurographs/geometry.py index b9d0d0b..ba191a4 100644 --- a/src/deep_neurographs/geometry.py +++ b/src/deep_neurographs/geometry.py @@ -188,7 +188,7 @@ def smooth_branch(xyz, s=None): Returns ------- - xyz : numpy.ndarray + numpy.ndarray Smoothed points. """ @@ -199,7 +199,7 @@ def smooth_branch(xyz, s=None): return xyz.astype(np.float32) -def fit_spline(xyz, s=None): +def fit_spline(xyz, k=3, s=None): """ Fits a cubic spline to an array containing xyz coordinates. @@ -222,9 +222,9 @@ def fit_spline(xyz, s=None): """ s = xyz.shape[0] / 10 if not s else xyz.shape[0] / s t = np.linspace(0, 1, xyz.shape[0]) - spline_x = UnivariateSpline(t, xyz[:, 0], s=s, k=3) - spline_y = UnivariateSpline(t, xyz[:, 1], s=s, k=3) - spline_z = UnivariateSpline(t, xyz[:, 2], s=s, k=3) + spline_x = UnivariateSpline(t, xyz[:, 0], k=k, s=s) + spline_y = UnivariateSpline(t, xyz[:, 1], k=k, s=s) + spline_z = UnivariateSpline(t, xyz[:, 2], k=k, s=s) return spline_x, spline_y, spline_z @@ -245,8 +245,9 @@ def sample_curve(xyz_arr, n_pts): Resampled points along curve. """ + k = 1 if xyz_arr.shape[0] <= 3 else 3 t = np.linspace(0, 1, n_pts) - spline_x, spline_y, spline_z = fit_spline(xyz_arr, s=0) + spline_x, spline_y, spline_z = fit_spline(xyz_arr, k=k, s=0) xyz = np.column_stack((spline_x(t), spline_y(t), spline_z(t))) return xyz.astype(int) diff --git a/src/deep_neurographs/machine_learning/heterogeneous_graph_datasets.py b/src/deep_neurographs/machine_learning/heterogeneous_graph_datasets.py index 90dbb5e..b631cff 100644 --- a/src/deep_neurographs/machine_learning/heterogeneous_graph_datasets.py +++ b/src/deep_neurographs/machine_learning/heterogeneous_graph_datasets.py @@ -164,11 +164,11 @@ def init_edge_attrs(self, x_nodes): """ # Proposal edges edge_type = ("proposal", "to", "proposal") - attrs = self.set_edge_attrs(x_nodes, edge_type, self.idxs_proposals) - # --> set attr + self.set_edge_attrs(x_nodes, edge_type, self.idxs_proposals) # Branch edges edge_type = ("branch", "to", "branch") + self.set_edge_attrs(x_nodes, edge_type, self.idxs_branches) # Branch-Proposal edges edge_type = ("branch", "to", "proposal") @@ -270,9 +270,9 @@ def set_edge_attrs(self, x_nodes, edge_type, idx_mapping): e1, e2 = self.data[edge_type][:, i] v = node_intersection(idx_mapping, e1, e2) attrs.append(x_nodes[v]) - print(v) - print(attrs) - stop + arrs = torch.tensor(np.array(attrs), dtype=DTYPE) + self.data[edge_type].edge_attr = arrs + # -- utils -- def init_idxs(idxs):