diff --git a/src/deep_neurographs/machine_learning/heterograph_trainer.py b/src/deep_neurographs/machine_learning/gnn_trainer.py
similarity index 96%
rename from src/deep_neurographs/machine_learning/heterograph_trainer.py
rename to src/deep_neurographs/machine_learning/gnn_trainer.py
index 8dc5866..bd32650 100644
--- a/src/deep_neurographs/machine_learning/heterograph_trainer.py
+++ b/src/deep_neurographs/machine_learning/gnn_trainer.py
@@ -33,7 +33,7 @@
 WEIGHT_DECAY = 1e-3
 
 
-class HeteroGraphTrainer:
+class Trainer:
     """
     Custom class that trains graph neural networks.
 
@@ -107,9 +107,9 @@ def run(self, train_dataset_list, validation_dataset_list):
             # Train
             y, hat_y = [], []
             self.model.train()
-            for graph_dataset in train_dataset_list:
+            for dataset in train_dataset_list:
                 # Forward pass
-                hat_y_i, y_i = self.predict(graph_dataset.data)
+                hat_y_i, y_i = self.predict(dataset.data)
                 loss = self.criterion(hat_y_i, y_i)
                 self.writer.add_scalar("loss", loss, epoch)
 
@@ -129,8 +129,8 @@ def run(self, train_dataset_list, validation_dataset_list):
             if epoch % 10 == 0:
                 y, hat_y = [], []
                 self.model.eval()
-                for graph_dataset in validation_dataset_list:
-                    hat_y_i, y_i = self.predict(graph_dataset.data)
+                for dataset in validation_dataset_list:
+                    hat_y_i, y_i = self.predict(dataset.data)
                     y.extend(toCPU(y_i))
                     hat_y.extend(toCPU(hat_y_i))
                 test_score = self.compute_metrics(y, hat_y, "val", epoch)
diff --git a/src/deep_neurographs/machine_learning/graph_trainer.py b/src/deep_neurographs/machine_learning/graph_trainer.py
deleted file mode 100644
index cc9b9ef..0000000
--- a/src/deep_neurographs/machine_learning/graph_trainer.py
+++ /dev/null
@@ -1,420 +0,0 @@
-"""
-Created on Sat April 12 11:00:00 2024
-
-@author: Anna Grim
-@email: anna.grim@alleninstitute.org
-
-Routines for training graph neural networks that classify edge proposals.
-
-"""
-
-from copy import deepcopy
-from random import sample, shuffle
-
-import numpy as np
-import torch
-from sklearn.metrics import (
-    accuracy_score,
-    f1_score,
-    precision_score,
-    recall_score,
-)
-from torch.optim.lr_scheduler import StepLR
-from torch.utils.tensorboard import SummaryWriter
-from torch_geometric.utils import subgraph
-
-from deep_neurographs.utils import gnn_util, ml_util
-
-# Training
-LR = 1e-3
-MODEL_TYPE = "GraphNeuralNet"
-N_EPOCHS = 200
-SCHEDULER_GAMMA = 0.5
-SCHEDULER_STEP_SIZE = 1000
-TEST_PERCENT = 0.15
-WEIGHT_DECAY = 1e-3
-
-# Augmentation
-MAX_PROPOSAL_DROPOUT = 0.1
-SCALING_FACTOR = 0.05
-
-
-class GraphTrainer:
-    """
-    Custom class that trains graph neural networks.
-
-    """
-
-    def __init__(
-        self,
-        model,
-        criterion,
-        lr=LR,
-        n_epochs=N_EPOCHS,
-        max_proposal_dropout=MAX_PROPOSAL_DROPOUT,
-        scaling_factor=SCALING_FACTOR,
-        weight_decay=WEIGHT_DECAY,
-    ):
-        """
-        Constructs a GraphTrainer object.
-
-        Parameters
-        ----------
-        model : torch.nn.Module
-            Graph neural network.
-        criterion : torch.nn.Module._Loss
-            Loss function.
-        lr : float, optional
-            Learning rate. The default is the global variable LR.
-        n_epochs : int
-            Number of epochs. The default is the global variable N_EPOCHS.
-        weight_decay : float
-            Weight decay used in optimizer. The default is the global variable
-            WEIGHT_DECAY.
-
-        Returns
-        -------
-        None.
- - """ - # Training - self.model = model # .to("cuda:0") - self.criterion = criterion - self.n_epochs = n_epochs - self.optimizer = torch.optim.Adam( - model.parameters(), lr=lr, weight_decay=weight_decay - ) - self.init_scheduler() - self.writer = SummaryWriter() - - # Augmentation - self.scaling_factor = scaling_factor - self.max_proposal_dropout = max_proposal_dropout - - def init_scheduler(self): - self.scheduler = StepLR( - self.optimizer, - step_size=SCHEDULER_STEP_SIZE, - gamma=SCHEDULER_GAMMA, - ) - - def run_on_graphs(self, datasets, augment=False): - """ - Trains a graph neural network in the case where "datasets" is a - dictionary of datasets such that each corresponds to a distinct graph. - - Parameters - ---------- - datasets : dict - Dictionary where each key is a graph id and the value is the - corresponding graph dataset. - - Returns - ------- - torch.nn.Module - Graph neural network that has been fit onto "datasets". - - """ - # Initializations - best_score = -np.inf - best_ckpt = None - - # Main - train_ids, test_ids = train_test_split(list(datasets.keys())) - for epoch in range(self.n_epochs): - # Train - y, hat_y = [], [] - self.model.train() - for graph_id in train_ids: - print(graph_id) - y_i, hat_y_i = self.train( - datasets[graph_id], epoch, augment=augment - ) - y.extend(gnn_util.toCPU(y_i)) - hat_y.extend(gnn_util.toCPU(hat_y_i)) - self.compute_metrics(y, hat_y, "train", epoch) - self.scheduler.step() - - # Test - if epoch % 10 == 0: - y, hat_y = [], [] - self.model.eval() - for graph_id in test_ids: - y_i, hat_y_i = self.forward(datasets[graph_id].data) - y.extend(gnn_util.toCPU(y_i)) - hat_y.extend(gnn_util.toCPU(hat_y_i)) - test_score = self.compute_metrics(y, hat_y, "val", epoch) - - # Check for best - if test_score > best_score: - best_score = test_score - best_ckpt = deepcopy(self.model.state_dict()) - self.model.load_state_dict(best_ckpt) - return self.model - - def run_on_graph(self): - """ - Trains a graph neural network in the case where "dataset" is a - graph that may contain multiple connected components. - - Parameters - ---------- - dataset : dict - Dictionary where each key is a graph id and the value is the - corresponding graph dataset. - - Returns - ------- - None - - """ - pass - - def train(self, dataset, epoch, augment=False): - """ - Performs the forward pass and backpropagation to update the model's - weights. - - Parameters - ---------- - data : GraphDataset - Graph dataset that corresponds to a single connected component. - epoch : int - Current epoch. - augment : bool, optional - Indication of whether to augment data. Default is False. - - Returns - ------- - torch.Tensor - Ground truth. - torch.Tensor - Prediction. - - """ - # Data augmentation (if applicable) - if self.augment: - data = self.augment(dataset) - else: - data = deepcopy(dataset.data) - - # Forward - y, hat_y = self.forward(data) - self.backpropagate(y, hat_y, epoch) - return y, hat_y - - def augment(self, dataset): - augmented_data = rescale_data(dataset, self.scaling_factor) - # augmented_data = proposal_dropout(data, self.max_proposal_dropout) - return augmented_data - - def forward(self, data): - """ - Runs "data" through "self.model" to generate a prediction. - - Parameters - ---------- - data : GraphDataset - Graph dataset that corresponds to a single connected component. - - Returns - ------- - torch.Tensor - Ground truth. - torch.Tensor - Prediction. 
- - """ - self.optimizer.zero_grad() - x, edge_index = gnn_util.get_inputs(data, MODEL_TYPE) - hat_y = self.model(x, edge_index) - y = data.y # .to("cuda:0", dtype=torch.float32) - return y, truncate(hat_y, y) - - def backpropagate(self, y, hat_y, epoch): - """ - Runs backpropagation to update the model's weights. - - Parameters - ---------- - y : torch.Tensor - Ground truth. - hat_y : torch.Tensor - Prediction. - epoch : int - Current epoch. - - Returns - ------- - None - - """ - loss = self.criterion(hat_y, y) - loss.backward() - self.optimizer.step() - self.writer.add_scalar("loss", loss, epoch) - - def compute_metrics(self, y, hat_y, prefix, epoch): - """ - Computes and logs evaluation metrics for binary classification. - - Parameters - ---------- - y : torch.Tensor - Ground truth. - hat_y : torch.Tensor - Prediction. - prefix : str - Prefix to be added to the metric names when logging. - epoch : int - Current epoch. - - Returns - ------- - float - F1 score. - - """ - # Initializations - y = np.array(y, dtype=int).tolist() - hat_y = get_predictions(hat_y) - - # Compute - accuracy = accuracy_score(y, hat_y) - accuracy_dif = accuracy - np.sum(y) / len(y) - precision = precision_score(y, hat_y) - recall = recall_score(y, hat_y) - f1 = f1_score(y, hat_y) - - # Log - self.writer.add_scalar(prefix + "_accuracy:", accuracy, epoch) - self.writer.add_scalar(prefix + "_accuracy_df:", accuracy_dif, epoch) - self.writer.add_scalar(prefix + "_precision:", precision, epoch) - self.writer.add_scalar(prefix + "_recall:", recall, epoch) - self.writer.add_scalar(prefix + "_f1:", f1, epoch) - return f1 - - -# -- util -- -def shuffler(my_list): - """ - Shuffles a list of items. - - Parameters - ---------- - my_list : list - List to be shuffled. - - Returns - ------- - list - Shuffled list. - - """ - shuffle(my_list) - return my_list - - -def train_test_split(graph_ids): - """ - Split a list of graph IDs into training and testing sets. - - Parameters - ---------- - graph_ids : list[str] - A list containing unique identifiers (IDs) for graphs. - - Returns - ------- - train_ids : list - A list containing IDs for the training set. - test_ids : list - A list containing IDs for the testing set. - - """ - n_test_examples = int(len(graph_ids) * TEST_PERCENT) - test_ids = ["block_000", "block_002"] # sample(graph_ids, n_test_examples) - train_ids = list(set(graph_ids) - set(test_ids)) - return train_ids, test_ids - - -def truncate(hat_y, y): - """ - Truncates "hat_y" so that this tensor has the same shape as "y". Note this - operation removes the predictions corresponding to branches so that loss - is computed over proposals. - - Parameters - ---------- - hat_y : torch.Tensor - Tensor to be truncated. - y : torch.Tensor - Tensor used as a reference. - - Returns - ------- - torch.Tensor - Truncated "hat_y". - - """ - return hat_y[: y.size(0), 0] - - -def get_predictions(hat_y, threshold=0.5): - """ - Generate binary predictions based on the input probabilities. - - Parameters - ---------- - hat_y : torch.Tensor - Predicted probabilities generated by "self.model". - threshold : float, optional - The threshold value for binary classification. The default is 0.5. - - Returns - ------- - list[int] - Binary predictions based on the given threshold. 
- - """ - return (ml_util.sigmoid(np.array(hat_y)) > threshold).tolist() - - -def connected_components(data): - cc_list = [] - cc_idxs = torch.unique(data.edge_index[0], return_inverse=True)[1] - for i in range(cc_idxs.max().item() + 1): - cc_list.append(torch.nonzero(cc_idxs == i, as_tuple=False).view(-1)) - return cc_list - - -def rescale_data(dataset, scaling_factor): - # Get scaling factor - low = 1.0 - scaling_factor - high = 1.0 + scaling_factor - scaling_factor = torch.tensor(np.random.uniform(low=low, high=high)) - - # Rescale - n = count_proposals(dataset) - data = deepcopy(dataset.data) - data.x[0:n, 1] = scaling_factor * data.x[0:n, 1] - return data - - -def proposal_dropout(data, max_proposal_dropout): - n_dropout_edges = len(data.dropout_edges) // 2 - dropout_prob = np.random.uniform(low=0, high=max_proposal_dropout) - n_remove = int(dropout_prob * n_dropout_edges) - remove_edges = sample(data.dropout_edges, n_remove) - for edge in remove_edges: - reversed_edge = [edge[1], edge[0]] - edges_to_remove = torch.tensor([edge, reversed_edge], dtype=torch.long) - edges_mask = torch.all( - data.data.edge_index.T == edges_to_remove[:, None], dim=2 - ).any(dim=0) - data.data.edge_index = data.data.edge_index[:, ~edges_mask] - return data - - -def count_proposals(dataset): - return dataset.data.y.size(0) diff --git a/src/deep_neurographs/machine_learning/heterograph_datasets.py b/src/deep_neurographs/machine_learning/heterograph_datasets.py index 963e2b8..c585d48 100644 --- a/src/deep_neurographs/machine_learning/heterograph_datasets.py +++ b/src/deep_neurographs/machine_learning/heterograph_datasets.py @@ -140,7 +140,6 @@ def __init__( self.check_missing_edge_type() self.init_edge_attrs(x_nodes) self.n_edge_attrs = n_edge_features(x_nodes) - def init_edges(self): """ diff --git a/src/deep_neurographs/machine_learning/heterograph_feature_generation.py b/src/deep_neurographs/machine_learning/heterograph_feature_generation.py index 32fc446..e79abc3 100644 --- a/src/deep_neurographs/machine_learning/heterograph_feature_generation.py +++ b/src/deep_neurographs/machine_learning/heterograph_feature_generation.py @@ -16,7 +16,6 @@ from deep_neurographs.machine_learning import feature_generation as feats from deep_neurographs.utils import img_util - N_PROFILE_PTS = 16 NODE_PROFILE_DEPTH = 16 WINDOW = [5, 5, 5] diff --git a/src/deep_neurographs/train_pipeline.py b/src/deep_neurographs/train_pipeline.py index 7fad91a..f400840 100644 --- a/src/deep_neurographs/train_pipeline.py +++ b/src/deep_neurographs/train_pipeline.py @@ -17,11 +17,12 @@ from torch.nn import BCEWithLogitsLoss from deep_neurographs.machine_learning import feature_generation +from deep_neurographs.machine_learning.gnn_trainer import Trainer from deep_neurographs.utils import img_util, ml_util, util from deep_neurographs.utils.graph_util import GraphLoader -class Trainer: +class TrainingPipeline: """ Class that is used to train a machine learning model that classifies proposals. 
@@ -53,8 +54,8 @@ def __init__(
         self.gt_graphs = list()
         self.pred_graphs = list()
         self.imgs = dict()
-        self.train_dataset = list()
-        self.validation_dataset = list()
+        self.train_dataset_list = list()
+        self.validation_dataset_list = list()
 
         # Train parameters
         self.criterion = criterion if criterion else BCEWithLogitsLoss()
@@ -73,10 +74,10 @@ def n_examples(self):
         return len(self.gt_graphs)
 
     def n_train_examples(self):
-        return len(self.train_dataset)
+        return len(self.train_dataset_list)
 
     def n_validation_samples(self):
-        return len(self.validation_dataset)
+        return len(self.validation_dataset_list)
 
     def set_validation_idxs(self):
         if self.validation_ids is None:
@@ -131,9 +132,24 @@ def load_img(self, path, sample_id):
 
     # --- main pipeline ---
     def run(self):
+        # Initialize training data
         self.generate_proposals()
         self.generate_features()
-        self.train_model()
+
+        # Train model
+        trainer = Trainer(
+            self.model,
+            self.criterion,
+            lr=self.ml_config.lr,
+            n_epochs=self.ml_config.n_epochs,
+        )
+        self.model = trainer.run(
+            self.train_dataset_list, self.validation_dataset_list
+        )
+
+        # Save model (if applicable)
+        if self.save_model_bool:
+            self.save_model()
 
     def generate_proposals(self):
         print("sample_id - example_id - # proposals - % accepted")
@@ -184,15 +200,12 @@ def generate_features(self):
                 computation_graph=proposals_dict["graph"]
             )
             if i in self.validation_ids:
-                self.validation_dataset.append(dataset)
+                self.validation_dataset_list.append(dataset)
             else:
-                self.train_dataset.append(dataset)
-
-    def train_model(self):
-        pass
+                self.train_dataset_list.append(dataset)
 
-    def save_model(self, model):
+    def save_model(self):
         name = self.model_type + "-" + datetime.today().strftime('%Y-%m-%d')
         extension = ".pth" if "Net" in self.model_type else ".joblib"
         path = os.path.join(self.output_dir, name + extension)
-        util.save_model(path, model)
+        ml_util.save_model(path, self.model, self.model_type)
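
Note: the removed GraphTrainer.run_on_graphs and the renamed Trainer.run in gnn_trainer.py implement the same train/validate/checkpoint loop: train every epoch, validate every tenth epoch, and keep the state dict with the best validation score. The following is a minimal self-contained sketch of that pattern, not the deep_neurographs API: the linear model and random tensors are placeholders for the GNN and graph datasets, and plain accuracy stands in for the F1 score returned by compute_metrics.

    from copy import deepcopy

    import torch
    from torch.nn import BCEWithLogitsLoss

    # Placeholder model, loss, and optimizer (defaults mirror gnn_trainer.py)
    model = torch.nn.Linear(8, 1)
    criterion = BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)

    # Placeholder train/validation data
    x_train, y_train = torch.randn(64, 8), torch.randint(0, 2, (64, 1)).float()
    x_val, y_val = torch.randn(16, 8), torch.randint(0, 2, (16, 1)).float()

    best_score, best_ckpt = -float("inf"), None
    for epoch in range(200):  # N_EPOCHS default
        # Train
        model.train()
        optimizer.zero_grad()
        loss = criterion(model(x_train), y_train)
        loss.backward()
        optimizer.step()

        # Validate every 10th epoch and checkpoint the best model
        if epoch % 10 == 0:
            model.eval()
            with torch.no_grad():
                hat_y = (torch.sigmoid(model(x_val)) > 0.5).float()
            score = (hat_y == y_val).float().mean().item()  # stand-in for F1
            if score > best_score:
                best_score, best_ckpt = score, deepcopy(model.state_dict())
    model.load_state_dict(best_ckpt)

Validating only every tenth epoch keeps evaluation cheap, and deep-copying the state dict means the returned model is the best validation scorer rather than whatever the last epoch produced.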