Skip to content

Commit

Permalink
black reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Nov 7, 2023
1 parent 54b3654 commit 97d0a35
Show file tree
Hide file tree
Showing 19 changed files with 50 additions and 1,048 deletions.
Binary file removed src/data/Planetoid/Cora/processed/data.pt
Binary file not shown.
Binary file removed src/data/Planetoid/Cora/processed/pre_filter.pt
Binary file not shown.
Binary file removed src/data/Planetoid/Cora/processed/pre_transform.pt
Binary file not shown.
Binary file removed src/data/Planetoid/Cora/raw/ind.cora.allx
Binary file not shown.
Binary file removed src/data/Planetoid/Cora/raw/ind.cora.ally
Binary file not shown.
Binary file removed src/data/Planetoid/Cora/raw/ind.cora.graph
Binary file not shown.
1,000 changes: 0 additions & 1,000 deletions src/data/Planetoid/Cora/raw/ind.cora.test.index

This file was deleted.

Binary file removed src/data/Planetoid/Cora/raw/ind.cora.tx
Binary file not shown.
Binary file removed src/data/Planetoid/Cora/raw/ind.cora.ty
Binary file not shown.
Binary file removed src/data/Planetoid/Cora/raw/ind.cora.x
Binary file not shown.
Binary file removed src/data/Planetoid/Cora/raw/ind.cora.y
Binary file not shown.
23 changes: 10 additions & 13 deletions src/deep_neurographs/deep_learning/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,15 +245,15 @@ def __init__(self):
None
"""
self.blur = tio.RandomBlur(std=(0, 0.4))
self.noise = tio.RandomNoise(std=(0, 0.03))
self.apply_geometric = tio.Compose(
{
# tio.RandomFlip(axes=(0, 1, 2)),
tio.RandomAffine(
degrees=20, scales=(0.8, 1), image_interpolation="nearest"
)
}
self.transform = tio.Compose(
[
tio.RandomBlur(std=(0, 0.4)),
tio.RandomNoise(std=(0, 0.03)),
tio.RandomFlip(axes=(0, 1, 2)),
# tio.RandomAffine(
# degrees=20, scales=(0.8, 1), image_interpolation="nearest"
# )
]
)

def run(self, arr):
Expand All @@ -271,10 +271,7 @@ def run(self, arr):
Transformed array after being run through augmentation pipeline.
"""
arr = self.blur(arr)
arr = self.noise(arr)
arr = self.apply_geometric(arr)
return arr
return self.transform(arr)


def reformat(arr):
Expand Down
16 changes: 12 additions & 4 deletions src/deep_neurographs/deep_learning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch.nn.functional as F
import torch.utils.data as torch_data
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.profilers import PyTorchProfiler
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier
from torch.utils.data import DataLoader
from torcheval.metrics.functional import (
Expand All @@ -26,15 +27,14 @@
binary_recall,
)

from deep_neurographs import utils
from deep_neurographs.deep_learning import datasets as ds
from deep_neurographs.deep_learning import models

logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)


BATCH_SIZE = 32
NUM_WORKERS = 1
NUM_WORKERS = 0
SHUFFLE = True
SUPPORTED_CLFS = [
"AdaBoost",
Expand Down Expand Up @@ -90,6 +90,7 @@ def train_network(
logger=True,
max_epochs=100,
model_summary=True,
profile=False,
progress_bar=True,
):
# Load data
Expand All @@ -101,14 +102,20 @@ def train_network(
shuffle=SHUFFLE,
)
valid_loader = DataLoader(
valid_set, num_workers=NUM_WORKERS, batch_size=BATCH_SIZE
valid_set,
batch_size=BATCH_SIZE,
num_workers=NUM_WORKERS,
pin_memory=True,
)

# Fit model
# Configure trainer
model = LitNeuralNet(net)
checkpoint_callback = ModelCheckpoint(
save_top_k=1, monitor="val_f1", mode="max"
)
profiler = PyTorchProfiler() if profile else None

# Fit model
trainer = pl.Trainer(
accelerator="gpu",
callbacks=[checkpoint_callback],
Expand All @@ -117,6 +124,7 @@ def train_network(
enable_progress_bar=progress_bar,
logger=logger,
max_epochs=max_epochs,
profiler=profiler,
)
trainer.fit(model, train_loader, valid_loader)
return model
Expand Down
30 changes: 15 additions & 15 deletions src/deep_neurographs/densegraph.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
"""
Created on Sat November 04 15:30:00 2023
@author: Anna Grim
@email: [email protected]
Class of graphs that are built from swc files.
"""

import os

import networkx as nx
Expand All @@ -6,7 +16,7 @@
from scipy.spatial import KDTree

from deep_neurographs import swc_utils, utils
from deep_neurographs.geometry_utils import dist, make_line
from deep_neurographs.geometry_utils import dist


class DenseGraph:
Expand Down Expand Up @@ -69,27 +79,17 @@ def check_aligned(self, pred_xyz_i, pred_xyz_j):
if self.xyz_to_swc[xyz_i] != self.xyz_to_swc[xyz_j]:
return False

# Compare pred and target distances
# Compute distances
pred_xyz_i = np.array(pred_xyz_i)
pred_xyz_j = np.array(pred_xyz_j)
pred_dist = dist(pred_xyz_i, pred_xyz_j)

target_path, target_dist = self.connect_nodes(graph_id, xyz_i, xyz_j)
target_dist = max(target_dist, 1)

# Check criteria
ratio = min(pred_dist, target_dist) / max(pred_dist, target_dist)
if ratio < 0.5 and pred_dist > 15:
return False
# elif ratio < 0.2:
# return False

# Compare projected predicted path
proj_dists = []
proj_nodes = set()
for xyz in make_line(pred_xyz_i, pred_xyz_j, len(target_path)):
proj_xyz, proj_d = self.get_projection(xyz)
swc = self.xyz_to_swc[tuple(proj_xyz)]
proj_nodes.add(self.xyz_to_node[swc][tuple(proj_xyz)])
proj_dists.append(proj_d)

return True
else:
return True
10 changes: 5 additions & 5 deletions src/deep_neurographs/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ def generate_mutable_features(
"""
features = {"skel": generate_mutable_skel_features(neurograph)}
if img_path and img_profile:
features["img"] = generate_mutable_img_profile_features(
neurograph, img_path, anisotropy=anisotropy
)
elif img_path and not img_profile:
if img_path and labels_path:
features["img"] = generate_mutable_img_chunk_features(
neurograph, img_path, labels_path, anisotropy=anisotropy
)
elif img_path and img_profile:
features["img"] = generate_mutable_img_profile_features(
neurograph, img_path, anisotropy=anisotropy
)
return features


Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/geometry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def get_profile(
def fill_path(img, path, val=-1):
for xyz in path:
x, y, z = tuple(np.round(xyz).astype(int))
img[x - 1 : x + 1, y - 1 : y + 1, z - 1 : z + 1] = val
img[(x - 1) : x + 1, (y - 1) : y + 1, (z - 1) : z + 1] = val
return img


Expand Down
3 changes: 0 additions & 3 deletions src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@

import os

import torch
from torch_geometric.data import Data

from deep_neurographs import s3_utils, swc_utils, utils
from deep_neurographs.neurograph import NeuroGraph

Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def init_targets(self, target_neurograph):
proj_xyz_j, d_j = target_neurograph.get_projection(xyz_j)

# Check criteria
if d_i > 9 or d_j > 9:
if d_i > 8 or d_j > 8:
continue
elif self.check_cycle((i, j)):
continue
Expand Down
12 changes: 6 additions & 6 deletions src/deep_neurographs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,18 +133,18 @@ def open_tensorstore(path):

def read_img_chunk(img, xyz, shape):
return img[
xyz[2] - shape[2] // 2 : xyz[2] + shape[2] // 2,
xyz[1] - shape[1] // 2 : xyz[1] + shape[1] // 2,
xyz[0] - shape[0] // 2 : xyz[0] + shape[0] // 2,
(xyz[2] - shape[2] // 2) : xyz[2] + shape[2] // 2,
(xyz[1] - shape[1] // 2) : xyz[1] + shape[1] // 2,
(xyz[0] - shape[0] // 2) : xyz[0] + shape[0] // 2,
].transpose(2, 1, 0)


def read_tensorstore(ts_arr, xyz, shape):
arr = (
ts_arr[
xyz[0] - shape[0] // 2 : xyz[0] + shape[0] // 2,
xyz[1] - shape[1] // 2 : xyz[1] + shape[1] // 2,
xyz[2] - shape[2] // 2 : xyz[2] + shape[2] // 2,
(xyz[0] - shape[0] // 2) : xyz[0] + shape[0] // 2,
(xyz[1] - shape[1] // 2) : xyz[1] + shape[1] // 2,
(xyz[2] - shape[2] // 2) : xyz[2] + shape[2] // 2,
]
.read()
.result()
Expand Down

0 comments on commit 97d0a35

Please sign in to comment.