From 851aabe87c2b119fbf3785d6928918102f38ea53 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Fri, 27 Sep 2024 18:06:14 +0000 Subject: [PATCH] refactor: combined train engine and pipeline --- ...{run_pipeline.py => inference_pipeline.py} | 2 +- .../groundtruth_generation.py | 2 +- .../machine_learning/heterograph_models.py | 6 +- .../{gnn_trainer.py => train.py} | 209 ++++++++++++++++- .../machine_learning/trainer.py | 157 ------------- src/deep_neurographs/train_pipeline.py | 212 ------------------ 6 files changed, 209 insertions(+), 379 deletions(-) rename src/deep_neurographs/{run_pipeline.py => inference_pipeline.py} (99%) rename src/deep_neurographs/machine_learning/{gnn_trainer.py => train.py} (50%) delete mode 100644 src/deep_neurographs/machine_learning/trainer.py delete mode 100644 src/deep_neurographs/train_pipeline.py diff --git a/src/deep_neurographs/run_pipeline.py b/src/deep_neurographs/inference_pipeline.py similarity index 99% rename from src/deep_neurographs/run_pipeline.py rename to src/deep_neurographs/inference_pipeline.py index 4b4db4b..b7628bc 100644 --- a/src/deep_neurographs/run_pipeline.py +++ b/src/deep_neurographs/inference_pipeline.py @@ -38,7 +38,7 @@ from deep_neurographs.utils.graph_util import GraphLoader -class GraphTracePipeline: +class InferencePipeline: """ Class that executes the full GraphTrace inference pipeline. diff --git a/src/deep_neurographs/machine_learning/groundtruth_generation.py b/src/deep_neurographs/machine_learning/groundtruth_generation.py index 0095630..ad12616 100644 --- a/src/deep_neurographs/machine_learning/groundtruth_generation.py +++ b/src/deep_neurographs/machine_learning/groundtruth_generation.py @@ -19,7 +19,7 @@ from deep_neurographs.utils import graph_util as gutil from deep_neurographs.utils import util -ALIGNED_THRESHOLD = 3.5 +ALIGNED_THRESHOLD = 4 MIN_INTERSECTION = 10 diff --git a/src/deep_neurographs/machine_learning/heterograph_models.py b/src/deep_neurographs/machine_learning/heterograph_models.py index a27cd52..fd9794e 100644 --- a/src/deep_neurographs/machine_learning/heterograph_models.py +++ b/src/deep_neurographs/machine_learning/heterograph_models.py @@ -16,7 +16,7 @@ from torch_geometric.nn import GATv2Conv as GATConv from torch_geometric.nn import HEATConv, HeteroConv, Linear -from deep_neurographs.machine_learning import heterograph_feature_generation +from deep_neurographs import machine_learning as ml CONV_TYPES = ["GATConv", "GCNConv"] DROPOUT = 0.3 @@ -43,8 +43,8 @@ def __init__( """ super().__init__() # Feature vector sizes - node_dict = heterograph_feature_generation.n_node_features() - edge_dict = heterograph_feature_generation.n_edge_features() + node_dict = ml.heterograph_feature_generation.n_node_features() + edge_dict = ml.heterograph_feature_generation.n_edge_features() hidden_dim = scale_hidden_dim * np.max(list(node_dict.values())) # Linear layers diff --git a/src/deep_neurographs/machine_learning/gnn_trainer.py b/src/deep_neurographs/machine_learning/train.py similarity index 50% rename from src/deep_neurographs/machine_learning/gnn_trainer.py rename to src/deep_neurographs/machine_learning/train.py index 2c17c2a..99ed883 100644 --- a/src/deep_neurographs/machine_learning/gnn_trainer.py +++ b/src/deep_neurographs/machine_learning/train.py @@ -4,13 +4,14 @@ @author: Anna Grim @email: anna.grim@alleninstitute.org -Routines for training heterogeneous graph neural networks that classify -edge proposals. +Routines for training machine learning models that classify proposals. 
""" +import os from copy import deepcopy -from random import shuffle +from datetime import datetime +from random import sample, shuffle import numpy as np import torch @@ -20,11 +21,14 @@ precision_score, recall_score, ) +from torch.nn import BCEWithLogitsLoss from torch.optim.lr_scheduler import StepLR from torch.utils.tensorboard import SummaryWriter -from deep_neurographs.utils import gnn_util, ml_util +from deep_neurographs.machine_learning import feature_generation +from deep_neurographs.utils import gnn_util, img_util, ml_util, util from deep_neurographs.utils.gnn_util import toCPU +from deep_neurographs.utils.graph_util import GraphLoader LR = 1e-3 N_EPOCHS = 200 @@ -33,7 +37,197 @@ WEIGHT_DECAY = 1e-3 -class Trainer: +class TrainPipeline: + """ + Class that is used to train a machine learning model that classifies + proposals. + + """ + def __init__( + self, + config, + model, + model_type, + criterion=None, + output_dir=None, + validation_ids=None, + save_model_bool=True, + ): + # Check for parameter errors + if save_model_bool and not output_dir: + raise ValueError("Must provide output_dir to save model.") + + # Set class attributes + self.idx_to_ids = list() + self.model = model + self.model_type = model_type + self.output_dir = output_dir + self.save_model_bool = save_model_bool + self.validation_ids = validation_ids + + # Set data structures for training examples + self.gt_graphs = list() + self.pred_graphs = list() + self.imgs = dict() + self.train_dataset_list = list() + self.validation_dataset_list = list() + + # Train parameters + self.criterion = criterion if criterion else BCEWithLogitsLoss() + self.validation_ids = validation_ids + + # Extract config settings + self.graph_config = config.graph_config + self.ml_config = config.ml_config + self.graph_loader = GraphLoader( + min_size=self.graph_config.min_size, + progress_bar=False, + ) + + # --- getters/setters --- + def n_examples(self): + return len(self.gt_graphs) + + def n_train_examples(self): + return len(self.train_dataset_list) + + def n_validation_samples(self): + return len(self.validation_dataset_list) + + def set_validation_idxs(self): + if self.validation_ids is None: + k = int(self.ml_config.validation_split * self.n_examples()) + self.validation_idxs = sample(np.arange(self.n_examples), k) + else: + self.validation_idxs = list() + for ids in self.validation_ids: + for i in range(self.n_examples()): + same = all([ids[k] == self.idx_to_ids[i][k] for k in ids]) + if same: + self.validation_idxs.append(i) + assert len(self.validation_idxs) > 0, "No validation data!" 
+
+    # --- loaders ---
+    def load_example(
+        self,
+        gt_pointer,
+        pred_pointer,
+        sample_id,
+        example_id=None,
+        pred_id=None,
+        metadata_path=None,
+    ):
+        # Read metadata
+        if metadata_path:
+            origin, shape = util.read_metadata(metadata_path)
+        else:
+            origin, shape = None, None
+
+        # Load graphs
+        self.gt_graphs.append(self.graph_loader.run(gt_pointer))
+        self.pred_graphs.append(
+            self.graph_loader.run(
+                pred_pointer,
+                img_patch_origin=origin,
+                img_patch_shape=shape,
+            )
+        )
+
+        # Set example ids
+        self.idx_to_ids.append(
+            {
+                "sample_id": sample_id,
+                "example_id": example_id,
+                "pred_id": pred_id,
+            }
+        )
+
+    def load_img(self, path, sample_id):
+        if sample_id not in self.imgs:
+            self.imgs[sample_id] = img_util.open_tensorstore(path, "zarr")
+
+    # --- main pipeline ---
+    def run(self):
+        # Initialize training data
+        self.set_validation_idxs()
+        self.generate_proposals()
+        self.generate_features()
+
+        # Train model
+        train_engine = TrainEngine(
+            self.model,
+            self.criterion,
+            lr=self.ml_config.lr,
+            n_epochs=self.ml_config.n_epochs,
+        )
+        self.model = train_engine.run(
+            self.train_dataset_list, self.validation_dataset_list
+        )
+
+        # Save model (if applicable)
+        if self.save_model_bool:
+            self.save_model()
+
+    def generate_proposals(self):
+        print("sample_id - example_id - # proposals - % accepted")
+        for i in range(self.n_examples()):
+            # Generate proposals between leaves of the predicted graph
+            self.pred_graphs[i].generate_proposals(
+                self.graph_config.search_radius,
+                complex_bool=self.graph_config.complex_bool,
+                groundtruth_graph=self.gt_graphs[i],
+                long_range_bool=self.graph_config.long_range_bool,
+                progress_bar=False,
+                proposals_per_leaf=self.graph_config.proposals_per_leaf,
+                trim_endpoints_bool=self.graph_config.trim_endpoints_bool,
+            )
+
+            # Report results
+            sample_id = self.idx_to_ids[i]["sample_id"]
+            example_id = self.idx_to_ids[i]["example_id"]
+            n_proposals = self.pred_graphs[i].n_proposals()
+            n_targets = len(self.pred_graphs[i].target_edges)
+            p_accepts = round(n_targets / n_proposals, 4)
+            print(f"{sample_id} {example_id} {n_proposals} {p_accepts}")
+
+    def generate_features(self):
+        for i in range(self.n_examples()):
+            # Get proposals
+            proposals_dict = {
+                "proposals": self.pred_graphs[i].list_proposals(),
+                "graph": self.pred_graphs[i].copy_graph(),
+            }
+
+            # Generate features
+            sample_id = self.idx_to_ids[i]["sample_id"]
+            features = feature_generation.run(
+                self.pred_graphs[i],
+                self.imgs[sample_id],
+                self.model_type,
+                proposals_dict,
+                self.graph_config.search_radius,
+            )
+
+            # Initialize train and validation datasets
+            dataset = ml_util.init_dataset(
+                self.pred_graphs[i],
+                features,
+                self.model_type,
+                computation_graph=proposals_dict["graph"],
+            )
+            if i in self.validation_idxs:
+                self.validation_dataset_list.append(dataset)
+            else:
+                self.train_dataset_list.append(dataset)
+
+    def save_model(self):
+        name = self.model_type + "-" + datetime.today().strftime('%Y-%m-%d')
+        extension = ".pth" if "Net" in self.model_type else ".joblib"
+        path = os.path.join(self.output_dir, name + extension)
+        ml_util.save_model(path, self.model, self.model_type)
+
+
+class TrainEngine:
     """
     Custom class that trains graph neural networks.
@@ -205,6 +399,11 @@ def compute_metrics(self, y, hat_y, prefix, epoch): return f1 +def fit_random_forest(model, dataset): + model.fit(dataset.data.x, dataset.data.y) + return model + + # -- util -- def shuffler(my_list): """ diff --git a/src/deep_neurographs/machine_learning/trainer.py b/src/deep_neurographs/machine_learning/trainer.py deleted file mode 100644 index 954f125..0000000 --- a/src/deep_neurographs/machine_learning/trainer.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -Created on Sat November 04 15:30:00 2023 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -Routines for training models that classify edge proposals. - -""" - -import logging - -import lightning.pytorch as pl -import torch -import torch.nn as nn -import torch.utils.data as torch_data -from lightning.pytorch.callbacks import ModelCheckpoint -from torch.nn.functional import sigmoid -from torch.utils.data import DataLoader -from torcheval.metrics.functional import ( - binary_f1_score, - binary_precision, - binary_recall, -) - -logging.getLogger("pytorch_lightning").setLevel(logging.ERROR) - -BATCH_SIZE = 32 -SHUFFLE = True -SUPPORTED_MODELS = [ - "AdaBoost", - "RandomForest", - "FeedForwardNet", - "ConvNet", - "MultiModalNet", -] - - -def fit_model(model, dataset): - model.fit(dataset.data.x, dataset.data.y) - return model - - -def fit_deep_model( - model, - dataset, - batch_size=BATCH_SIZE, - criterion=None, - logger=False, - lr=1e-3, - max_epochs=1000, -): - """ - Fits a neural network to a dataset. - - Parameters - ---------- - model : ... - ... - dataset : ... - ... - lr : float, optional - Learning rate to be used if model is a neural network. The default is - 1e-3. - logger : bool, optional - Indication of whether to log performance stats while neural network - trains. The default is False. - max_epochs : int, optional - Maximum number of epochs used to train neural network. The default is - 50. - - Returns - ------- - ... 
- """ - # Load data - train_set, valid_set = random_split(dataset.data) - train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) - valid_loader = DataLoader(valid_set, batch_size=batch_size) - - # Configure trainer - lit_model = LitModel(criterion=criterion, model=model, lr=lr) - ckpt_callback = ModelCheckpoint(save_top_k=1, monitor="val_f1", mode="max") - - # Fit model - pylightning_trainer = pl.Trainer( - accelerator="gpu", - callbacks=[ckpt_callback], - devices=1, - enable_model_summary=False, - enable_progress_bar=False, - logger=logger, - log_every_n_steps=1, - max_epochs=max_epochs, - ) - pylightning_trainer.fit(lit_model, train_loader, valid_loader) - - # Return best model - ckpt = torch.load(ckpt_callback.best_model_path) - lit_model.model.load_state_dict(ckpt["state_dict"]) - return lit_model.model - - -def random_split(train_set, train_ratio=0.8): - train_set_size = int(len(train_set) * train_ratio) - valid_set_size = len(train_set) - train_set_size - return torch_data.random_split(train_set, [train_set_size, valid_set_size]) - - -# -- Lightning Module -- -class LitModel(pl.LightningModule): - def __init__(self, criterion=None, model=None, lr=1e-3): - super().__init__() - self.model = model - self.lr = lr - if criterion: - self.criterion = criterion - else: - pos_weight = torch.tensor([1.0], device=0) - self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight) - - def forward(self, batch): - x = self.get_example(batch, "inputs") - return self.model(x) - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) - return optimizer - - def training_step(self, batch, batch_idx): - X = self.get_example(batch, "inputs") - y = self.get_example(batch, "targets") - y_hat = self.model(X) - - loss = self.criterion(y_hat, y) - self.log("train_loss", loss) - self.compute_stats(y_hat, y, prefix="train_") - return loss - - def validation_step(self, batch, batch_idx): - X = self.get_example(batch, "inputs") - y = self.get_example(batch, "targets") - y_hat = self.model(X) - self.compute_stats(y_hat, y, prefix="val_") - - def compute_stats(self, y_hat, y, prefix=""): - y_hat = torch.flatten(sigmoid(y_hat)) - y = torch.flatten(y).to(torch.int) - self.log(prefix + "precision", binary_precision(y_hat, y)) - self.log(prefix + "recall", binary_recall(y_hat, y)) - self.log(prefix + "f1", binary_f1_score(y_hat, y)) - - def get_example(self, batch, key): - return batch[key] - - def state_dict(self, destination=None, prefix="", keep_vars=False): - return self.model.state_dict(destination, prefix + "", keep_vars) diff --git a/src/deep_neurographs/train_pipeline.py b/src/deep_neurographs/train_pipeline.py deleted file mode 100644 index c25bb77..0000000 --- a/src/deep_neurographs/train_pipeline.py +++ /dev/null @@ -1,212 +0,0 @@ -""" -Created on Sat Sept 16 11:30:00 2024 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - - -This script trains the GraphTrace inference pipeline. - -""" - -import os -from datetime import datetime -from random import sample - -import numpy as np -from torch.nn import BCEWithLogitsLoss - -from deep_neurographs.machine_learning import feature_generation -from deep_neurographs.machine_learning.gnn_trainer import Trainer -from deep_neurographs.utils import img_util, ml_util, util -from deep_neurographs.utils.graph_util import GraphLoader - - -class TrainingPipeline: - """ - Class that is used to train a machine learning model that classifies - proposals. 
- - """ - def __init__( - self, - config, - model, - model_type, - criterion=None, - output_dir=None, - validation_ids=None, - save_model_bool=True, - ): - # Check for parameter errors - if save_model_bool and not output_dir: - raise ValueError("Must provide output_dir to save model.") - - # Set class attributes - self.idx_to_ids = list() - self.model = model - self.model_type = model_type - self.output_dir = output_dir - self.save_model_bool = save_model_bool - self.validation_ids = validation_ids - - # Set data structures for training examples - self.gt_graphs = list() - self.pred_graphs = list() - self.imgs = dict() - self.train_dataset_list = list() - self.validation_dataset_list = list() - - # Train parameters - self.criterion = criterion if criterion else BCEWithLogitsLoss() - self.validation_ids = validation_ids - - # Extract config settings - self.graph_config = config.graph_config - self.ml_config = config.ml_config - self.graph_loader = GraphLoader( - min_size=self.graph_config.min_size, - progress_bar=False, - ) - - # --- getters/setters --- - def n_examples(self): - return len(self.gt_graphs) - - def n_train_examples(self): - return len(self.train_dataset_list) - - def n_validation_samples(self): - return len(self.validation_dataset_list) - - def set_validation_idxs(self): - if self.validation_ids is None: - k = int(self.ml_config.validation_split * self.n_examples()) - self.validation_idxs = sample(np.arange(self.n_examples), k) - else: - self.validation_idxs = list() - for ids in self.validation_ids: - for i in range(self.n_examples()): - same = all([ids[k] == self.idx_to_ids[i][k] for k in ids]) - if same: - self.validation_idxs.append(i) - - # --- loaders --- - def load_example( - self, - gt_pointer, - pred_pointer, - sample_id, - example_id=None, - pred_id=None, - metadata_path=None, - ): - # Read metadata - if metadata_path: - origin, shape = util.read_metadata(metadata_path) - else: - origin, shape = None, None - - # Load graphs - self.gt_graphs.append(self.graph_loader.run(gt_pointer)) - self.pred_graphs.append( - self.graph_loader.run( - pred_pointer, - img_patch_origin=origin, - img_patch_shape=shape, - ) - ) - - # Set example ids - self.idx_to_ids.append( - { - "sample_id": sample_id, - "example_id": example_id, - "pred_id": pred_id, - } - ) - - def load_img(self, path, sample_id): - if sample_id not in self.imgs: - self.imgs[sample_id] = img_util.open_tensorstore(path, "zarr") - - # --- main pipeline --- - def run(self): - # Initialize training data - self.generate_proposals() - self.generate_features() - self.set_validation_idxs() - assert len(self.validation_dataset_list) > 0, "No validation data!" 
- - # Train model - trainer = Trainer( - self.model, - self.criterion, - lr=self.ml_config.lr, - n_epochs=self.ml_config.n_epochs, - ) - self.model = trainer.run( - self.train_dataset_list, self.validation_dataset_list - ) - - # Save model (if applicable) - if self.save_model_bool: - self.save_model() - - def generate_proposals(self): - print("sample_id - example_id - # proposals - % accepted") - for i in range(self.n_examples()): - # Run - self.pred_graphs[i].generate_proposals( - self.graph_config.search_radius, - complex_bool=self.graph_config.complex_bool, - groundtruth_graph=self.gt_graphs[i], - long_range_bool=self.graph_config.long_range_bool, - progress_bar=False, - proposals_per_leaf=self.graph_config.proposals_per_leaf, - trim_endpoints_bool=self.graph_config.trim_endpoints_bool, - ) - - # Report results - sample_id = self.idx_to_ids[i]["sample_id"] - example_id = self.idx_to_ids[i]["example_id"] - n_proposals = self.pred_graphs[i].n_proposals() - n_targets = len(self.pred_graphs[i].target_edges) - p_accepts = round(n_targets / n_proposals, 4) - print(f"{sample_id} {example_id} {n_proposals} {p_accepts}") - - def generate_features(self): - for i in range(self.n_examples()): - # Get proposals - proposals_dict = { - "proposals": self.pred_graphs[i].list_proposals(), - "graph": self.pred_graphs[i].copy_graph() - } - - # Generate features - sample_id = self.idx_to_ids[i]["sample_id"] - features = feature_generation.run( - self.pred_graphs[i], - self.imgs[sample_id], - self.model_type, - proposals_dict, - self.graph_config.search_radius, - ) - - # Initialize train and validation datasets - dataset = ml_util.init_dataset( - self.pred_graphs[i], - features, - self.model_type, - computation_graph=proposals_dict["graph"] - ) - if i in self.validation_idxs: - self.validation_dataset_list.append(dataset) - else: - self.train_dataset_list.append(dataset) - - def save_model(self): - name = self.model_type + "-" + datetime.today().strftime('%Y-%m-%d') - extension = ".pth" if "Net" in self.model_type else ".joblib" - path = os.path.join(self.output_dir, name + extension) - ml_util.save_model(path, self.model, self.model_type)