From 8c6c306d01f147f34a21b6745f464f4b9d5adfff Mon Sep 17 00:00:00 2001
From: Cyril Achard <94955160+C-Achard@users.noreply.github.com>
Date: Sun, 22 Dec 2024 20:57:26 +0100
Subject: [PATCH] Colab notebooks update (#103)

* Fix use of deprecated arg in colab training

* Refactor model save name path + comment wandb cell

* Update Colab_WNet3D_training.ipynb

* Improve logging in Colab

* Subclass WnetTraininWorker to avoid duplication

* Remove strict channel first

* Add missing channel_dim, remove strict_check=False

* Update worker_training.py

* Update worker_training.py

* Disable strict checks for channelfirstd

* Update worker_training.py

* Temp disable channel first

* Fix init of Colab worker

* Move issues with transforms to colab script + disable pad/channelfirst

* Enable ChannelFirst again

* Remove strict_check = False in original worker

Seems to be a Colab-specific issue

* Remove redundant code + Colab notebook tweaks

* Revert wandb check

* Update docs + Colab inference

* Update training_wnet.rst

* Update Colab_WNet3D_training.ipynb

* update / WIP

* Update Colab_inference_demo.ipynb

* Update Colab_inference_demo.ipynb

* Update Colab_inference_demo.ipynb

* Update Colab_inference_demo.ipynb

* Update Colab_inference_demo.ipynb

* Update Colab_inference_demo.ipynb

* Update Colab_inference_demo.ipynb

* Update Colab_inference_demo.ipynb

* nearly final!

* exec

* final

---------

Co-authored-by: Mackenzie Mathis
Co-authored-by: Mackenzie Mathis
---
 docs/source/guides/training_wnet.rst          |  21 +-
 .../code_models/worker_training.py            |  20 +-
 .../dev_scripts/colab_training.py             | 676 +-----------
 notebooks/Colab_WNet3D_training.ipynb         | 973 +++++++++---------
 notebooks/Colab_inference_demo.ipynb          | 261 +++--
 5 files changed, 689 insertions(+), 1262 deletions(-)

diff --git a/docs/source/guides/training_wnet.rst b/docs/source/guides/training_wnet.rst
index 21fff524..359dc321 100644
--- a/docs/source/guides/training_wnet.rst
+++ b/docs/source/guides/training_wnet.rst
@@ -18,24 +18,21 @@ The WNet3D **does not require a large amount of data to train**, but **choosing
 
 You may find below some guidelines, based on our own data and testing.
 
-The WNet3D is designed to segment objects based on their brightness, and is particularly well-suited for images with a clear contrast between objects and background.
-
-The WNet3D is not suitable for images with artifacts, therefore care should be taken that the images are clean and that the objects are at least somewhat distinguishable from the background.
+The WNet3D is a self-supervised learning approach for 3D cell segmentation, and relies on the assumption that structural and morphological features of cells can be inferred directly from unlabeled data. This involves leveraging inherent properties such as spatial coherence and local contrast in imaging volumes to distinguish cellular structures. This approach assumes that meaningful representations of cellular boundaries and nuclei can emerge solely from raw 3D volumes. Thus, we strongly recommend that you use WNet3D on stacks that have clear foreground/background segregation and limited noise. Even if your final samples have noise, it is best to train on data that is as clean as possible.
 
 .. important::
     For optimal performance, the following should be avoided for training:
 
-    - Images with very large, bright regions
-    - Almost-empty and empty images
-    - Images with large empty regions or "holes"
+    - Images with over-exposed pixels/artifacts that you do not want the model to learn!
+    - Almost-empty and/or fully empty images, especially if noise is present (it will learn to segment very small objects!).
 
-    However, the model may be accomodate:
+    However, the model may accommodate:
 
-    - Uneven brightness distribution
-    - Varied object shapes and radius
-    - Noisy images
-    - Uneven illumination across the image
+    - Uneven brightness distribution in your image!
+    - Varied object shapes and radii!
+    - Noisy images (as long as resolution is sufficient and boundaries are clear)!
+    - Uneven illumination across the image!
 
 For optimal results, during inference, images should be similar to those the model was trained on; however this is not a strict requirement.
 
@@ -88,7 +85,7 @@ Common issues troubleshooting
 
 If you do not find a satisfactory answer here, please do not hesitate to `open an issue`_ on GitHub.
 
-- **The NCuts loss "explodes" after a few epochs** : Lower the learning rate, for example start with a factor of two, then ten.
+- **The NCuts loss "explodes" upward after a few epochs** : Lower the learning rate; for example, start with a factor of two, then ten.
 
 - **Reconstruction (decoder) performance is poor** : First, try increasing the weight of the reconstruction loss. If this is ineffective, switch to BCE loss and set the scaling factor of the reconstruction loss to 0.5, OR adjust the weight of the MSE loss.
 
diff --git a/napari_cellseg3d/code_models/worker_training.py b/napari_cellseg3d/code_models/worker_training.py
index 05c576ac..e6d3173b 100644
--- a/napari_cellseg3d/code_models/worker_training.py
+++ b/napari_cellseg3d/code_models/worker_training.py
@@ -1,4 +1,5 @@
 """Contains the workers used to train the models."""
+
 import platform
 import time
 from abc import abstractmethod
@@ -200,7 +201,10 @@ def get_patch_dataset(self, train_transforms):
         patch_func = Compose(
             [
                 LoadImaged(keys=["image"], image_only=True),
-                EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
+                EnsureChannelFirstd(
+                    keys=["image"],
+                    channel_dim="no_channel",
+                ),
                 RandSpatialCropSamplesd(
                     keys=["image"],
                     roi_size=(
@@ -235,7 +239,8 @@ def get_dataset_eval(self, eval_dataset_dict):
             [
                 LoadImaged(keys=["image", "label"]),
                 EnsureChannelFirstd(
-                    keys=["image", "label"], channel_dim="no_channel"
+                    keys=["image", "label"],
+                    channel_dim="no_channel",
                 ),
                 # RandSpatialCropSamplesd(
                 #     keys=["image", "label"],
@@ -280,7 +285,10 @@ def get_dataset(self, train_transforms):
         load_single_images = Compose(
             [
                 LoadImaged(keys=["image"]),
-                EnsureChannelFirstd(keys=["image"]),
+                EnsureChannelFirstd(
+                    keys=["image"],
+                    channel_dim="no_channel",
+                ),
                 Orientationd(keys=["image"], axcodes="PLI"),
                 SpatialPadd(
                     keys=["image"],
@@ -1345,9 +1353,9 @@ def get_patch_loader_func(num_samples):
             )
             sample_loader_eval = get_patch_loader_func(num_val_samples)
         else:
-            num_train_samples = (
-                num_val_samples
-            ) = self.config.num_samples
+            num_train_samples = num_val_samples = (
+                self.config.num_samples
+            )
 
             sample_loader_train = get_patch_loader_func(
                 num_train_samples
diff --git a/napari_cellseg3d/dev_scripts/colab_training.py b/napari_cellseg3d/dev_scripts/colab_training.py
index a5020fec..79bcfdbb 100644
--- a/napari_cellseg3d/dev_scripts/colab_training.py
+++ b/napari_cellseg3d/dev_scripts/colab_training.py
@@ -1,64 +1,60 @@
 """Script to run WNet training in Google Colab."""
+
 import time
 from pathlib import Path
+from typing import TYPE_CHECKING
 
-import torch
-import torch.nn as nn
+from monai.data import CacheDataset
 
 # MONAI
-from monai.data import (
-    CacheDataset,
-    DataLoader,
-    PatchDataset,
-    pad_list_data_collate,
-)
-from monai.data.meta_obj import set_track_meta
-from monai.inferers import sliding_window_inference
 from monai.metrics import DiceMetric
 from monai.transforms import (
-    AsDiscrete,
     Compose,
     EnsureChannelFirstd,
     EnsureTyped,
     LoadImaged,
     Orientationd,
-    RandFlipd,
-    RandRotate90d,
-    RandShiftIntensityd,
-    RandSpatialCropSamplesd,
-    ScaleIntensityRanged,
-    SpatialPadd,
 )
-from monai.utils import set_determinism
 
 # local
 from napari_cellseg3d import config, utils
-from napari_cellseg3d.code_models.models.wnet.model import WNet
-from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss
-from napari_cellseg3d.code_models.worker_training import TrainingWorkerBase
+from napari_cellseg3d.code_models.worker_training import WNetTrainingWorker
 from napari_cellseg3d.code_models.workers_utils import (
     PRETRAINED_WEIGHTS_DIR,
 )
 
+if TYPE_CHECKING:
+    from monai.data import DataLoader
+
 logger = utils.LOGGER
 VERBOSE_SCHEDULER = True
 logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {PRETRAINED_WEIGHTS_DIR}")
 
-try:
-    import wandb
-
-    WANDB_INSTALLED = True
-except ImportError:
-    logger.warning(
-        "wandb not installed, wandb config will not be taken into account",
-        stacklevel=1,
-    )
-    WANDB_INSTALLED = False
+class LogFixture:
+    """Fixture for napari-less logging; replaces napari_cellseg3d.interface.Log in model_workers.
+
+    This allows redirecting the output of the workers to stdout instead of a specialized widget.
+    """
 
-# TODO subclass to reduce code duplication
+    def __init__(self):
+        """Creates a LogFixture object."""
+        super(LogFixture, self).__init__()
 
+    def print_and_log(self, text, printing=None):
+        """Prints and logs text."""
+        print(text)
 
-class WNetTrainingWorkerColab(TrainingWorkerBase):
+    def warn(self, warning):
+        """Logs warning."""
+        logger.warning(warning)
+
+    def error(self, e):
+        """Logs error."""
+        raise (e)
+
+
+class WNetTrainingWorkerColab(WNetTrainingWorker):
     """A custom worker to run WNet (unsupervised) training jobs in.
 
     Inherits from :py:class:`napari.qt.threading.GeneratorWorker` via :py:class:`TrainingWorkerBase`.
@@ -75,8 +71,7 @@ def __init__(
             worker_config: worker configuration
             wandb_config: optional wandb configuration
         """
-        super().__init__()
-        self.config = worker_config
+        super().__init__(worker_config)
         self.wandb_config = (
             wandb_config if wandb_config is not None else config.WandBConfig()
         )
@@ -96,89 +92,11 @@ def __init__(
         self.eval_dataloader: DataLoader = None
         self.data_shape = None
 
-    def log(self, text):
-        """Log a message to the logger and to wandb if installed."""
-        logger.info(text)
-
-    def get_patch_dataset(self, train_transforms):
-        """Creates a Dataset from the original data using the tifffile library.
- - Args: - train_transforms (Compose): The transforms to apply to the data - - Returns: - (tuple): A tuple containing the shape of the data and the dataset - """ - patch_func = Compose( - [ - LoadImaged(keys=["image"], image_only=True), - EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"), - RandSpatialCropSamplesd( - keys=["image"], - roi_size=( - self.config.sample_size - ), # multiply by axis_stretch_factor if anisotropy - # max_roi_size=(120, 120, 120), - random_size=False, - num_samples=self.config.num_samples, - ), - Orientationd(keys=["image"], axcodes="PLI"), - SpatialPadd( - keys=["image"], - spatial_size=( - utils.get_padding_dim(self.config.sample_size) - ), - ), - EnsureTyped(keys=["image"]), - ] - ) - dataset = PatchDataset( - data=self.config.train_data_dict, - samples_per_image=self.config.num_samples, - patch_func=patch_func, - transform=train_transforms, - ) - - return self.config.sample_size, dataset - - def get_dataset_eval(self, eval_dataset_dict): - """Creates a Dataset applying some transforms/augmentation on the data using the MONAI library.""" - eval_transforms = Compose( - [ - LoadImaged(keys=["image", "label"]), - EnsureChannelFirstd( - keys=["image", "label"], channel_dim="no_channel" - ), - # RandSpatialCropSamplesd( - # keys=["image", "label"], - # roi_size=( - # self.config.sample_size - # ), # multiply by axis_stretch_factor if anisotropy - # # max_roi_size=(120, 120, 120), - # random_size=False, - # num_samples=self.config.num_samples, - # ), - Orientationd(keys=["image", "label"], axcodes="PLI"), - # SpatialPadd( - # keys=["image", "label"], - # spatial_size=( - # utils.get_padding_dim(self.config.sample_size) - # ), - # ), - EnsureTyped(keys=["image", "label"]), - ] - ) - - return CacheDataset( - data=eval_dataset_dict, - transform=eval_transforms, - ) - def get_dataset(self, train_transforms): """Creates a Dataset applying some transforms/augmentation on the data using the MONAI library. 
Args: - train_transforms (Compose): The transforms to apply to the data + train_transforms (monai.transforms.Compose): The transforms to apply to the data Returns: (tuple): A tuple containing the shape of the data and the dataset @@ -188,16 +106,25 @@ def get_dataset(self, train_transforms): first_volume = LoadImaged(keys=["image"])(train_files[0]) first_volume_shape = first_volume["image"].shape + if len(first_volume_shape) != 3: + raise ValueError( + f"Expected 3D volumes, got {len(first_volume_shape)} dimensions" + ) + # Transforms to be applied to each volume load_single_images = Compose( [ LoadImaged(keys=["image"]), - EnsureChannelFirstd(keys=["image"]), - Orientationd(keys=["image"], axcodes="PLI"), - SpatialPadd( + EnsureChannelFirstd( keys=["image"], - spatial_size=(utils.get_padding_dim(first_volume_shape)), + channel_dim="no_channel", + strict_check=False, ), + Orientationd(keys=["image"], axcodes="PLI"), + # SpatialPadd( + # keys=["image"], + # spatial_size=(utils.get_padding_dim(first_volume_shape)), + # ), EnsureTyped(keys=["image"]), # RemapTensord(keys=["image"], new_min=0.0, new_max=100.0), ] @@ -211,512 +138,6 @@ def get_dataset(self, train_transforms): return first_volume_shape, dataset - def _get_data(self): - if self.config.do_augmentation: - train_transforms = Compose( - [ - ScaleIntensityRanged( - keys=["image"], - a_min=0, - a_max=2000, - b_min=0.0, - b_max=1.0, - clip=True, - ), - RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5), - RandFlipd(keys=["image"], spatial_axis=[1], prob=0.5), - RandFlipd(keys=["image"], spatial_axis=[2], prob=0.5), - RandRotate90d(keys=["image"], prob=0.1, max_k=3), - EnsureTyped(keys=["image"]), - ] - ) - else: - train_transforms = EnsureTyped(keys=["image"]) - - if self.config.sampling: - logger.debug("Loading patch dataset") - (self.data_shape, dataset) = self.get_patch_dataset( - train_transforms - ) - else: - logger.debug("Loading volume dataset") - (self.data_shape, dataset) = self.get_dataset(train_transforms) - - logger.debug(f"Data shape : {self.data_shape}") - self.dataloader = DataLoader( - dataset, - batch_size=self.config.batch_size, - shuffle=True, - num_workers=self.config.num_workers, - collate_fn=pad_list_data_collate, - ) - - if self.config.eval_volume_dict is not None: - eval_dataset = self.get_dataset_eval(self.config.eval_volume_dict) - - self.eval_dataloader = DataLoader( - eval_dataset, - batch_size=self.config.batch_size, - shuffle=False, - num_workers=self.config.num_workers, - collate_fn=pad_list_data_collate, - ) - else: - self.eval_dataloader = None - return self.dataloader, self.eval_dataloader, self.data_shape - - def log_parameters(self): - """Log the parameters of the training.""" - self.log("*" * 20) - self.log("-- Parameters --") - self.log(f"Device: {self.config.device}") - self.log(f"Batch size: {self.config.batch_size}") - self.log(f"Epochs: {self.config.max_epochs}") - self.log(f"Learning rate: {self.config.learning_rate}") - self.log(f"Validation interval: {self.config.validation_interval}") - if self.config.weights_info.use_custom: - self.log(f"Custom weights: {self.config.weights_info.path}") - elif self.config.weights_info.use_pretrained: - self.log(f"Pretrained weights: {self.config.weights_info.path}") - if self.config.sampling: - self.log( - f"Using {self.config.num_samples} samples of size {self.config.sample_size}" - ) - if self.config.do_augmentation: - self.log("Using data augmentation") - ############## - self.log("-- Model --") - self.log(f"Using {self.config.num_classes} 
classes") - self.log(f"Weight decay: {self.config.weight_decay}") - self.log("* NCuts : ") - self.log(f"- Intensity sigma {self.config.intensity_sigma}") - self.log(f"- Spatial sigma {self.config.spatial_sigma}") - self.log(f"- Radius : {self.config.radius}") - self.log(f"* Reconstruction loss : {self.config.reconstruction_loss}") - self.log( - f"Weighted sum : {self.config.n_cuts_weight}*NCuts + {self.config.rec_loss_weight}*Reconstruction" - ) - ############## - self.log("-- Data --") - self.log("Training data :\n") - [ - self.log(f"{v}") - for d in self.config.train_data_dict - for k, v in d.items() - ] - if self.config.eval_volume_dict is not None: - self.log("\nValidation data :\n") - [ - self.log(f"{k}: {v}") - for d in self.config.eval_volume_dict - for k, v in d.items() - ] - self.log("*" * 20) - - def train( - self, provided_model=None, provided_optimizer=None, provided_loss=None - ): - """Train the model.""" - try: - if self.config is None: - self.config = config.WNetTrainingWorkerConfig() - ############## - # disable metadata tracking in MONAI - set_track_meta(False) - ############## - if WANDB_INSTALLED: - config_dict = self.config.__dict__ - logger.debug(f"wandb config : {config_dict}") - wandb.init( - config=config_dict, - project="CellSeg3D (Colab)", - name=f"{self.config.model_info.name} training - {utils.get_date_time()}", - mode=self.wandb_config.mode, - tags=["WNet3D", "Colab"], - ) - - set_determinism(seed=self.config.deterministic_config.seed) - torch.use_deterministic_algorithms(True, warn_only=True) - - device = self.config.device - - self.log_parameters() - self.log("Initializing training...") - self.log("- Getting the data") - - self._get_data() - - ################################################### - # Training the model # - ################################################### - self.log("- Getting the model") - # Initialize the model - model = ( - WNet( - in_channels=self.config.in_channels, - out_channels=self.config.out_channels, - num_classes=self.config.num_classes, - dropout=self.config.dropout, - ) - if provided_model is None - else provided_model - ) - model.to(device) - - if self.config.use_clipping: - for p in model.parameters(): - p.register_hook( - lambda grad: torch.clamp( - grad, - min=-self.config.clipping, - max=self.config.clipping, - ) - ) - - if WANDB_INSTALLED: - wandb.watch(model, log_freq=100) - - if self.config.weights_info.use_custom: - if self.config.weights_info.use_pretrained: - weights_file = "wnet.pth" - self.downloader.download_weights("WNet3D", weights_file) - weights = PRETRAINED_WEIGHTS_DIR / Path(weights_file) - self.config.weights_info.path = weights - else: - weights = str(Path(self.config.weights_info.path)) - - try: - model.load_state_dict( - torch.load( - weights, - map_location=self.config.device, - ), - strict=True, - ) - except RuntimeError as e: - logger.error(f"Error when loading weights : {e}") - logger.exception(e) - warn = ( - "WARNING:\nIt'd seem that the weights were incompatible with the model,\n" - "the model will be trained from random weights" - ) - self.log(warn) - self.warn(warn) - self._weight_error = True - else: - self.log("Model will be trained from scratch") - self.log("- Getting the optimizer") - # Initialize the optimizers - if self.config.weight_decay is not None: - decay = self.config.weight_decay - optimizer = torch.optim.Adam( - model.parameters(), - lr=self.config.learning_rate, - weight_decay=decay, - ) - else: - optimizer = torch.optim.Adam( - model.parameters(), lr=self.config.learning_rate - ) 
- if provided_optimizer is not None: - optimizer = provided_optimizer - self.log("- Getting the loss functions") - # Initialize the Ncuts loss function - criterionE = SoftNCutsLoss( - data_shape=self.data_shape, - device=device, - intensity_sigma=self.config.intensity_sigma, - spatial_sigma=self.config.spatial_sigma, - radius=self.config.radius, - ) - - if self.config.reconstruction_loss == "MSE": - criterionW = nn.MSELoss() - elif self.config.reconstruction_loss == "BCE": - criterionW = nn.BCELoss() - else: - raise ValueError( - f"Unknown reconstruction loss : {self.config.reconstruction_loss} not supported" - ) - - model.train() - - self.log("Ready") - self.log("Training the model") - self.log("*" * 20) - - # Train the model - for epoch in range(self.config.max_epochs): - self.log(f"Epoch {epoch + 1} of {self.config.max_epochs}") - - epoch_ncuts_loss = 0 - epoch_rec_loss = 0 - epoch_loss = 0 - - for _i, batch in enumerate(self.dataloader): - # raise NotImplementedError("testing") - image_batch = batch["image"].to(device) - # Normalize the image - for i in range(image_batch.shape[0]): - for j in range(image_batch.shape[1]): - image_batch[i, j] = self.normalize_function( - image_batch[i, j] - ) - - # Forward pass - enc, dec = model(image_batch) - # Compute the Ncuts loss - Ncuts = criterionE(enc, image_batch) - - epoch_ncuts_loss += Ncuts.item() - if WANDB_INSTALLED: - wandb.log({"Train/Ncuts loss": Ncuts.item()}) - - # Compute the reconstruction loss - if isinstance(criterionW, nn.MSELoss): - reconstruction_loss = criterionW(dec, image_batch) - elif isinstance(criterionW, nn.BCELoss): - reconstruction_loss = criterionW( - torch.sigmoid(dec), - utils.remap_image(image_batch, new_max=1), - ) - - epoch_rec_loss += reconstruction_loss.item() - if WANDB_INSTALLED: - wandb.log( - { - "Train/Reconstruction loss": reconstruction_loss.item() - } - ) - - # Backward pass for the reconstruction loss - optimizer.zero_grad() - alpha = self.config.n_cuts_weight - beta = self.config.rec_loss_weight - - loss = alpha * Ncuts + beta * reconstruction_loss - if provided_loss is not None: - loss = provided_loss - epoch_loss += loss.item() - - if WANDB_INSTALLED: - wandb.log( - {"Train/Weighted sum of losses": loss.item()} - ) - - loss.backward(loss) - optimizer.step() - yield epoch_loss - - self.ncuts_losses.append( - epoch_ncuts_loss / len(self.dataloader) - ) - self.rec_losses.append(epoch_rec_loss / len(self.dataloader)) - self.total_losses.append(epoch_loss / len(self.dataloader)) - - if WANDB_INSTALLED: - wandb.log({"Ncuts loss for epoch": self.ncuts_losses[-1]}) - wandb.log( - {"Reconstruction loss for epoch": self.rec_losses[-1]} - ) - wandb.log( - {"Sum of losses for epoch": self.total_losses[-1]} - ) - wandb.log( - { - "LR/Model learning rate": optimizer.param_groups[ - 0 - ]["lr"] - } - ) - - self.log(f"Ncuts loss: {self.ncuts_losses[-1]:.5f}") - self.log(f"Reconstruction loss: {self.rec_losses[-1]:.5f}") - self.log( - f"Weighted sum of losses: {self.total_losses[-1]:.5f}" - ) - if epoch > 0: - self.log( - f"Ncuts loss difference: {self.ncuts_losses[-1] - self.ncuts_losses[-2]:.5f}" - ) - self.log( - f"Reconstruction loss difference: {self.rec_losses[-1] - self.rec_losses[-2]:.5f}" - ) - self.log( - f"Weighted sum of losses difference: {self.total_losses[-1] - self.total_losses[-2]:.5f}" - ) - - if ( - self.eval_dataloader is not None - and (epoch + 1) % self.config.validation_interval == 0 - ): - model.eval() - self.log("Validating...") - self.eval(model, epoch) # validation - - eta = ( - 
(time.time() - self.start_time) - * (self.config.max_epochs / (epoch + 1) - 1) - / 60 - ) - self.log(f"ETA: {eta:.1f} minutes") - self.log("-" * 20) - - # Save the model - if epoch % 5 == 0: - torch.save( - model.state_dict(), - self.config.results_path_folder + "/wnet_.pth", - ) - - self.log("Training finished") - if self.best_dice > -1: - best_dice_epoch = epoch - self.log( - f"Best dice metric : {self.best_dice} at epoch {best_dice_epoch}" - ) - - if WANDB_INSTALLED: - wandb.log( - { - "Validation/Best Dice": self.best_dice, - "Validation/Best Dice epoch": best_dice_epoch, - } - ) - - # Save the model - self.log( - f"Saving the model to: {self.config.results_path_folder}/wnet.pth", - ) - save_weights_path = self.config.results_path_folder + "/wnet.pth" - torch.save( - model.state_dict(), - save_weights_path, - ) - - if WANDB_INSTALLED and self.wandb_config.save_model_artifact: - model_artifact = wandb.Artifact( - "WNet3D", - type="model", - description="CellSeg3D WNet3D", - metadata=self.config.__dict__, - ) - model_artifact.add_file(save_weights_path) - wandb.log_artifact(model_artifact) - - except Exception as e: - msg = f"Training failed with exception: {e}" - self.log(msg) - self.raise_error(e, msg) - self.quit() - raise e - - def eval(self, model, _): - """Evaluate the model on the validation set.""" - with torch.no_grad(): - device = self.config.device - for _k, val_data in enumerate(self.eval_dataloader): - val_inputs, val_labels = ( - val_data["image"].to(device), - val_data["label"].to(device), - ) - - # normalize val_inputs across channels - for i in range(val_inputs.shape[0]): - for j in range(val_inputs.shape[1]): - val_inputs[i][j] = self.normalize_function( - val_inputs[i][j] - ) - logger.debug(f"Val inputs shape: {val_inputs.shape}") - val_outputs = sliding_window_inference( - val_inputs, - roi_size=[64, 64, 64], - sw_batch_size=1, - predictor=model.forward_encoder, - overlap=0.1, - mode="gaussian", - sigma_scale=0.01, - progress=True, - ) - val_decoder_outputs = sliding_window_inference( - val_outputs, - roi_size=[64, 64, 64], - sw_batch_size=1, - predictor=model.forward_decoder, - overlap=0.1, - mode="gaussian", - sigma_scale=0.01, - progress=True, - ) - val_outputs = AsDiscrete(threshold=0.5)(val_outputs) - logger.debug(f"Val outputs shape: {val_outputs.shape}") - logger.debug(f"Val labels shape: {val_labels.shape}") - logger.debug( - f"Val decoder outputs shape: {val_decoder_outputs.shape}" - ) - - # dices = [] - # Find in which channel the labels are (avoid background) - # for channel in range(val_outputs.shape[1]): - # dices.append( - # utils.dice_coeff( - # y_pred=val_outputs[ - # 0, channel : (channel + 1), :, :, : - # ], - # y_true=val_labels[0], - # ) - # ) - # logger.debug(f"DICE COEFF: {dices}") - # max_dice_channel = torch.argmax( - # torch.Tensor(dices) - # ) - # logger.debug( - # f"MAX DICE CHANNEL: {max_dice_channel}" - # ) - self.dice_metric( - y_pred=val_outputs, - # [ - # :, - # max_dice_channel : (max_dice_channel + 1), - # :, - # :, - # :, - # ], - y=val_labels, - ) - - # aggregate the final mean dice result - metric = self.dice_metric.aggregate().item() - self.dice_values.append(metric) - self.log(f"Validation Dice score: {metric:.3f}") - if self.best_dice < metric <= 1: - self.best_dice = metric - # save the best model - save_best_path = self.config.results_path_folder - # save_best_path.mkdir(parents=True, exist_ok=True) - save_best_name = "wnet" - save_path = ( - str(Path(save_best_path) / save_best_name) - + "_best_metric.pth" - ) - 
self.log(f"Saving new best model to {save_path}") - torch.save(model.state_dict(), save_path) - - if WANDB_INSTALLED: - # log validation dice score for each validation round - wandb.log({"Validation/Dice metric": metric}) - - self.dice_metric.reset() - - val_decoder_outputs = None - del val_decoder_outputs - val_outputs = None - del val_outputs - val_labels = None - del val_labels - val_inputs = None - del val_inputs - def get_colab_worker( worker_config: config.WNetTrainingWorkerConfig, @@ -728,8 +149,13 @@ def get_colab_worker( worker_config (config.WNetTrainingWorkerConfig): config for the training worker wandb_config (config.WandBConfig): config for wandb """ - worker = WNetTrainingWorkerColab(worker_config) - worker.wandb_config = wandb_config + log = LogFixture() + worker = WNetTrainingWorkerColab(worker_config, wandb_config) + + worker.log_signal.connect(log.print_and_log) + worker.warn_signal.connect(log.warn) + worker.error_signal.connect(log.error) + return worker diff --git a/notebooks/Colab_WNet3D_training.ipynb b/notebooks/Colab_WNet3D_training.ipynb index 0c5ed172..334be037 100644 --- a/notebooks/Colab_WNet3D_training.ipynb +++ b/notebooks/Colab_WNet3D_training.ipynb @@ -1,498 +1,495 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "gpuType": "T4", - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "source": [ - "# **WNet3D: self-supervised 3D cell segmentation**\n", - "\n", - "---\n", - "\n", - "This notebook is part of the [CellSeg3D project](https://github.com/AdaptiveMotorControlLab/CellSeg3d) in the [Mathis Lab of Adaptive Intelligence](https://www.mackenziemathislab.org/).\n", - "\n", - "- 💜 The foundation of this notebook owes much to the **[ZeroCostDL4Mic](https://github.com/HenriquesLab/ZeroCostDL4Mic)** project and to the **[DeepLabCut](https://github.com/DeepLabCut/DeepLabCut)** team for bringing Colab into scientific open software." - ], - "metadata": { - "id": "BTUVNXX7R3Go" - } - }, - { - "cell_type": "markdown", - "source": [ - "#**1. 
Installing dependencies**\n", - "---" - ], - "metadata": { - "id": "zmVCksV0EfVT" - } - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "td_vf_pneSak" - }, - "outputs": [], - "source": [ - "#@markdown ##Play to install WNet dependencies\n", - "!pip install napari-cellseg3d" - ] - }, - { - "cell_type": "markdown", - "source": [ - "##**1.2 Load key dependencies**\n", - "---" - ], - "metadata": { - "id": "nqctRognFGDT" - } - }, - { - "cell_type": "code", - "source": [ - "# @title\n", - "from pathlib import Path\n", - "from napari_cellseg3d.dev_scripts import colab_training as c\n", - "from napari_cellseg3d.config import WNetTrainingWorkerConfig, WandBConfig, WeightsInfo, PRETRAINED_WEIGHTS_DIR" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "wOOhJjkxjXz-", - "outputId": "8f94416d-a482-4ec6-f980-a728e908d90d" - }, - "execution_count": 2, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:napari_cellseg3d.utils:wandb not installed, wandb config will not be taken into account\n", - "WARNING:napari_cellseg3d.utils:wandb not installed, wandb config will not be taken into account\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## (optional) **1.3 Initialize Weights & Biases integration **\n", - "---\n", - "If you wish to utilize Weights & Biases (WandB) for monitoring and logging your training session, execute the cell below.\n", - "To enable it, just input your API key in the space provided." - ], - "metadata": { - "id": "Ax-vJAWRwIKi" - } - }, - { - "cell_type": "code", - "source": [ - "!pip install -q wandb\n", - "import wandb\n", - "wandb.login()" - ], - "metadata": { - "id": "QNgC3awjwb7G" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# **2. Complete the Colab session**\n", - "---\n" - ], - "metadata": { - "id": "Zi9gRBHFFyX-" - } - }, - { - "cell_type": "markdown", - "source": [ - "\n", - "## **2.1. Check for GPU access**\n", - "---\n", - "\n", - "By default, this session is configured to use Python 3 and GPU acceleration. To verify or adjust these settings:\n", - "\n", - "Navigate to Runtime and select Change the Runtime type.\n", - "\n", - "For Runtime type, ensure it's set to Python 3 (the programming language this program is written in).\n", - "\n", - "Under Accelerator, choose GPU (Graphics Processing Unit).\n" - ], - "metadata": { - "id": "zSU-LYTfFnvF" - } - }, - { - "cell_type": "code", - "source": [ - "#@markdown ##Execute the cell below to verify if GPU access is available.\n", - "\n", - "import torch\n", - "if not torch.cuda.is_available():\n", - " print('You do not have GPU access.')\n", - " print('Did you change your runtime?')\n", - " print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n", - " print('Expect slow performance. 
To access GPU try reconnecting later')\n", - "\n", - "else:\n", - " print('You have GPU access')\n", - " !nvidia-smi\n" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Ie7bXiMgFtPH", - "outputId": "3276444c-5109-47b4-f507-ea9acaab15ad" - }, - "execution_count": 3, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "You have GPU access\n", - "Fri May 3 17:19:13 2024 \n", - "+---------------------------------------------------------------------------------------+\n", - "| NVIDIA-SMI 535.104.05 Driver Version: 535.104.05 CUDA Version: 12.2 |\n", - "|-----------------------------------------+----------------------+----------------------+\n", - "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n", - "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n", - "| | | MIG M. |\n", - "|=========================================+======================+======================|\n", - "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", - "| N/A 50C P8 10W / 70W | 3MiB / 15360MiB | 0% Default |\n", - "| | | N/A |\n", - "+-----------------------------------------+----------------------+----------------------+\n", - " \n", - "+---------------------------------------------------------------------------------------+\n", - "| Processes: |\n", - "| GPU GI CI PID Type Process name GPU Memory |\n", - "| ID ID Usage |\n", - "|=======================================================================================|\n", - "| No running processes found |\n", - "+---------------------------------------------------------------------------------------+\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## **2.2. Mount Google Drive**\n", - "---\n", - "To integrate this notebook with your personal data, save your data on Google Drive in accordance with the directory structures detailed in Section 0.\n", - "\n", - "1. **Run** the **cell** below and click on the provided link.\n", - "\n", - "2. Log in to your Google account and grant the necessary permissions by clicking 'Allow'.\n", - "\n", - "3. Copy the generated authorization code and paste it into the cell, then press 'Enter'. This grants Colab access to read and write data to your Google Drive.\n", - "\n", - "4. After completion, you can view your data in the notebook. Simply click the Files tab on the top left and select 'Refresh'." - ], - "metadata": { - "id": "X_bbk7RAF2yw" - } - }, - { - "cell_type": "code", - "source": [ - "# mount user's Google Drive to Google Colab.\n", - "from google.colab import drive\n", - "drive.mount('/content/gdrive')" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "AsIARCablq1V", - "outputId": "77ffdbd1-4c89-4a56-e3da-7777a607a328" - }, - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Mounted at /content/gdrive\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "** If you cannot see your files, reactivate your session by connecting to your hosted runtime.**\n", - "\n", - "\n", - "\"Example
Connect to a hosted runtime.
" - ], - "metadata": { - "id": "r6FI22lkQLTv" - } - }, - { - "cell_type": "code", - "source": [ - "# @title\n", - "# import wandb\n", - "# wandb.login()" - ], - "metadata": { - "id": "EtsK08ECwlnJ" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# **3. Select your parameters and paths**\n", - "---" - ], - "metadata": { - "id": "IkOpxYjaGM0m" - } - }, - { - "cell_type": "markdown", - "source": [ - "## **3.1. Choosing parameters**\n", - "\n", - "---\n", - "\n", - "### **Paths to the training data and model**\n", - "\n", - "* **`training_source`** specifies the paths to the training data. They must be a single multipage TIF file each\n", - "\n", - "* **`model_path`** specifies the directory where the model checkpoints will be saved.\n", - "\n", - "**Tip:** To easily copy paths, navigate to the 'Files' tab, right-click on a folder or file, and choose 'Copy path'.\n", - "\n", - "### **Training parameters**\n", - "\n", - "* **`number_of_epochs`** is the number of times the entire training data will be seen by the model. Default: 50\n", - "\n", - "* **`batchs_size`** is the number of image that will be bundled together at each training step. Default: 4\n", - "\n", - "* **`learning_rate`** is the step size of the update of the model's weight. Try decreasing it if the NCuts loss is unstable. Default: 2e-5\n", - "\n", - "* **`num_classes`** is the number of brightness clusters to segment the image in. Try raising it to 3 if you have artifacts or \"halos\" around your cells that have significantly different brightness. Default: 2\n", - "\n", - "* **`weight_decay`** is a regularization parameter used to prevent overfitting. Default: 0.01\n", - "\n", - "* **`validation_frequency`** is the frequency at which the provided evaluation data is used to estimate the model's performance.\n", - "\n", - "* **`intensity_sigma`** is the standard deviation of the feature similarity term. Default: 1\n", - "\n", - "* **`spatial_sigma`** is the standard deviation of the spatial proximity term. Default: 4\n", - "\n", - "* **`ncuts_radius`** is the radius for the NCuts loss computation, in pixels. Default: 2\n", - "\n", - "* **`rec_loss`** is the loss to use for the decoder. Can be Mean Square Error (MSE) or Binary Cross Entropy (BCE). Default : MSE\n", - "\n", - "* **`n_cuts_weight`** is the weight of the NCuts loss in the weighted sum for the backward pass. Default: 0.5\n", - "* **`rec_loss_weight`** is the weight of the reconstruction loss. 
Default: 0.005\n" - ], - "metadata": { - "id": "65FhTkYlGKRt" - } - }, - { - "cell_type": "code", - "source": [ - "#@markdown ###Path to the training data:\n", - "training_source = \"./gdrive/MyDrive/CELLSEG_BENCHMARK/DATA/WNET/VIP_full\" #@param {type:\"string\"}\n", - "#@markdown ###Model name and path to model folder:\n", - "model_path = \"./gdrive/MyDrive/CELLSEG_BENCHMARK/WNET_TRAINING_RESULTS\" #@param {type:\"string\"}\n", - "#@markdown ---\n", - "#@markdown ###Perform validation on a test dataset\n", - "do_validation = False #@param {type:\"boolean\"}\n", - "#@markdown ###Path to evaluation data (optional, use if checked above):\n", - "eval_source = \"./gdrive/MyDrive/CELLSEG_BENCHMARK/DATA/WNET/eval/vol/\" #@param {type:\"string\"}\n", - "eval_target = \"./gdrive/MyDrive/CELLSEG_BENCHMARK/DATA/WNET/eval/lab/\" #@param {type:\"string\"}\n", - "#@markdown ---\n", - "#@markdown ###Training parameters\n", - "number_of_epochs = 50 #@param {type:\"number\"}\n", - "#@markdown ###Default advanced parameters\n", - "use_default_advanced_parameters = False #@param {type:\"boolean\"}\n", - "#@markdown If not, please change:\n", - "\n", - "#@markdown Training parameters:\n", - "batch_size = 4 #@param {type:\"number\"}\n", - "learning_rate = 2e-5 #@param {type:\"number\"}\n", - "num_classes = 2 #@param {type:\"number\"}\n", - "weight_decay = 0.01 #@param {type:\"number\"}\n", - "#@markdown Validation parameters:\n", - "validation_frequency = 2 #@param {type:\"number\"}\n", - "#@markdown SoftNCuts parameters:\n", - "intensity_sigma = 1.0 #@param {type:\"number\"}\n", - "spatial_sigma = 4.0 #@param {type:\"number\"}\n", - "ncuts_radius = 2 #@param {type:\"number\"}\n", - "#@markdown Reconstruction loss:\n", - "rec_loss = \"MSE\" #@param[\"MSE\", \"BCE\"]\n", - "#@markdown Weighted sum of losses:\n", - "n_cuts_weight = 0.5 #@param {type:\"number\"}\n", - "rec_loss_weight = 0.005 #@param {type:\"number\"}" - ], - "metadata": { - "cellView": "form", - "id": "tTSCC6ChGuuA" - }, - "execution_count": 7, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [], - "metadata": { - "id": "HtoIo5GcKIXX" - } - }, - { - "cell_type": "markdown", - "source": [ - "# **4. Train the network**\n", - "---\n", - "\n", - "Important Reminder: Google Colab imposes a maximum session time to prevent extended GPU usage, such as for data mining. Ensure your training duration stays under 12 hours. If your training is projected to exceed this limit, consider reducing the `number_of_epochs`." - ], - "metadata": { - "id": "arWhMU6aKsri" - } - }, - { - "cell_type": "markdown", - "source": [ - "## **4.1. Initialize the config**\n", - "---" - ], - "metadata": { - "id": "L59J90S_Kva3" - } + { + "cell_type": "markdown", + "metadata": { + "id": "BTUVNXX7R3Go" + }, + "source": [ + "# **WNet3D: self-supervised 3D cell segmentation**\n", + "\n", + "---\n", + "\n", + "This notebook is part of the [CellSeg3D project](https://github.com/AdaptiveMotorControlLab/CellSeg3d) in the [Mathis Lab of Adaptive Intelligence](https://www.mackenziemathislab.org/).\n", + "\n", + "- 💜 The foundation of this notebook owes much to the **[ZeroCostDL4Mic](https://github.com/HenriquesLab/ZeroCostDL4Mic)** project and to the **[DeepLabCut](https://github.com/DeepLabCut/DeepLabCut)** team for bringing Colab into scientific open software." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zmVCksV0EfVT" + }, + "source": [ + "#**1. 
Installing dependencies**\n", + "---" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "td_vf_pneSak" + }, + "outputs": [], + "source": [ + "#@markdown ##Play to install CellSeg3D and WNet3D dependencies:\n", + "!pip install -q napari-cellseg3d\n", + "print(\"Dependencies installed\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nqctRognFGDT" + }, + "source": [ + "##**1.2 Load key dependencies**\n", + "---" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "wOOhJjkxjXz-", + "outputId": "8f94416d-a482-4ec6-f980-a728e908d90d" + }, + "outputs": [ { - "cell_type": "code", - "source": [ - "# @title\n", - "train_data_folder = Path(training_source)\n", - "results_path = Path(model_path)\n", - "results_path.mkdir(exist_ok=True)\n", - "eval_image_folder = Path(eval_source)\n", - "eval_label_folder = Path(eval_target)\n", - "\n", - "eval_dict = c.create_eval_dataset_dict(\n", - " eval_image_folder,\n", - " eval_label_folder,\n", - " ) if do_validation else None\n", - "\n", - "try:\n", - " import wandb\n", - " WANDB_INSTALLED = True\n", - "except ImportError:\n", - " WANDB_INSTALLED = False\n", - "\n", - "\n", - "train_config = WNetTrainingWorkerConfig(\n", - " device=\"cuda:0\",\n", - " max_epochs=number_of_epochs,\n", - " learning_rate=2e-5,\n", - " validation_interval=2,\n", - " batch_size=4,\n", - " num_workers=2,\n", - " weights_info=WeightsInfo(),\n", - " results_path_folder=str(results_path),\n", - " train_data_dict=c.create_dataset_dict_no_labs(train_data_folder),\n", - " eval_volume_dict=eval_dict,\n", - ") if use_default_advanced_parameters else WNetTrainingWorkerConfig(\n", - " device=\"cuda:0\",\n", - " max_epochs=number_of_epochs,\n", - " learning_rate=learning_rate,\n", - " validation_interval=validation_frequency,\n", - " batch_size=batch_size,\n", - " num_workers=2,\n", - " weights_info=WeightsInfo(),\n", - " results_path_folder=str(results_path),\n", - " train_data_dict=c.create_dataset_dict_no_labs(train_data_folder),\n", - " eval_volume_dict=eval_dict,\n", - " # advanced\n", - " num_classes=num_classes,\n", - " weight_decay=weight_decay,\n", - " intensity_sigma=intensity_sigma,\n", - " spatial_sigma=spatial_sigma,\n", - " radius=ncuts_radius,\n", - " reconstruction_loss=rec_loss,\n", - " n_cuts_weight=n_cuts_weight,\n", - " rec_loss_weight=rec_loss_weight,\n", - ")\n", - "wandb_config = WandBConfig(\n", - " mode=\"disabled\" if not WANDB_INSTALLED else \"online\",\n", - " save_model_artifact=False,\n", - ")" - ], - "metadata": { - "id": "YOgLyUwPjvUX" - }, - "execution_count": null, - "outputs": [] + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:napari_cellseg3d.utils:wandb not installed, wandb config will not be taken into account\n", + "WARNING:napari_cellseg3d.utils:wandb not installed, wandb config will not be taken into account\n" + ] + } + ], + "source": [ + "# @title\n", + "from pathlib import Path\n", + "from napari_cellseg3d.dev_scripts import colab_training as c\n", + "from napari_cellseg3d.config import WNetTrainingWorkerConfig, WandBConfig, WeightsInfo, PRETRAINED_WEIGHTS_DIR" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ax-vJAWRwIKi" + }, + "source": [ + "## Optional - *1.3 Initialize Weights & Biases integration*\n", + "---\n", + "If you wish to utilize Weights & Biases (WandB) for monitoring and logging your training session, uncomment and execute the cell below.\n", + "To enable 
it, just input your API key in the space provided." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QNgC3awjwb7G" + }, + "outputs": [], + "source": [ + "# !pip install -q wandb\n", + "# import wandb\n", + "# wandb.login()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Zi9gRBHFFyX-" + }, + "source": [ + "# **2. Complete the Colab session**\n", + "---\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zSU-LYTfFnvF" + }, + "source": [ + "\n", + "## **2.1. Check for GPU access**\n", + "---\n", + "\n", + "By default, this session is configured to use Python 3 and GPU acceleration. To verify or adjust these settings:\n", + "\n", + "Navigate to Runtime and select Change the Runtime type.\n", + "\n", + "For Runtime type, ensure it's set to Python 3 (the programming language this program is written in).\n", + "\n", + "Under Accelerator, choose GPU (Graphics Processing Unit).\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "Ie7bXiMgFtPH", + "outputId": "3276444c-5109-47b4-f507-ea9acaab15ad" + }, + "outputs": [ { - "cell_type": "markdown", - "source": [ - "## **4.2. Start training**\n", - "---" - ], - "metadata": { - "id": "idowGpeQPIm2" - } + "name": "stdout", + "output_type": "stream", + "text": [ + "You have GPU access\n", + "Fri May 3 17:19:13 2024 \n", + "+---------------------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 535.104.05 Driver Version: 535.104.05 CUDA Version: 12.2 |\n", + "|-----------------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|=========================================+======================+======================|\n", + "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", + "| N/A 50C P8 10W / 70W | 3MiB / 15360MiB | 0% Default |\n", + "| | | N/A |\n", + "+-----------------------------------------+----------------------+----------------------+\n", + " \n", + "+---------------------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=======================================================================================|\n", + "| No running processes found |\n", + "+---------------------------------------------------------------------------------------+\n" + ] + } + ], + "source": [ + "#@markdown ##Execute the cell below to verify if GPU access is available.\n", + "\n", + "import torch\n", + "if not torch.cuda.is_available():\n", + " print('You do not have GPU access.')\n", + " print('Did you change your runtime?')\n", + " print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n", + " print('Expect slow performance. To access GPU try reconnecting later')\n", + "\n", + "else:\n", + " print('You have GPU access')\n", + " !nvidia-smi\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "X_bbk7RAF2yw" + }, + "source": [ + "## **2.2. Mount Google Drive**\n", + "---\n", + "To integrate this notebook with your personal data, save your data on Google Drive in accordance with the directory structures detailed in Section 0.\n", + "\n", + "1. 
**Run** the **cell** below and click on the provided link.\n", + "\n", + "2. Log in to your Google account and grant the necessary permissions by clicking 'Allow'.\n", + "\n", + "3. Copy the generated authorization code and paste it into the cell, then press 'Enter'. This grants Colab access to read and write data to your Google Drive.\n", + "\n", + "4. After completion, you can view your data in the notebook. Simply click the Files tab on the top left and select 'Refresh'." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "AsIARCablq1V", + "outputId": "77ffdbd1-4c89-4a56-e3da-7777a607a328" + }, + "outputs": [ { - "cell_type": "code", - "source": [ - "# @title\n", - "worker = c.get_colab_worker(worker_config=train_config, wandb_config=wandb_config)\n", - "for epoch_loss in worker.train():\n", - " continue" - ], - "metadata": { - "id": "OXxKZhGMqguz" - }, - "execution_count": null, - "outputs": [] + "name": "stdout", + "output_type": "stream", + "text": [ + "Mounted at /content/gdrive\n" + ] } - ] + ], + "source": [ + "# mount user's Google Drive to Google Colab.\n", + "from google.colab import drive\n", + "drive.mount('/content/gdrive')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "r6FI22lkQLTv" + }, + "source": [ + "** If you cannot see your files, reactivate your session by connecting to your hosted runtime.**\n", + "\n", + "\n", + "\"Example
Connect to a hosted runtime.
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IkOpxYjaGM0m" + }, + "source": [ + "# **3. Select your parameters and paths**\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "65FhTkYlGKRt" + }, + "source": [ + "## **3.1. Choosing parameters**\n", + "\n", + "---\n", + "\n", + "### **Paths to the training data and model**\n", + "\n", + "* **`training_source`** specifies the paths to the training data. They must be a single multipage TIF file each\n", + "\n", + "* **`model_save_path`** specifies the directory where the model checkpoints will be saved.\n", + "\n", + "**Tip:** To easily copy paths, navigate to the 'Files' tab, right-click on a folder or file, and choose 'Copy path'.\n", + "\n", + "### **Training parameters**\n", + "\n", + "* **`number_of_epochs`** is the number of times the entire training data will be seen by the model. Default: 50\n", + "\n", + "* **`batchs_size`** is the number of image that will be bundled together at each training step. Default: 4\n", + "\n", + "* **`learning_rate`** is the step size of the update of the model's weight. Try decreasing it if the NCuts loss is unstable. Default: 2e-5\n", + "\n", + "* **`num_classes`** is the number of brightness clusters to segment the image in. Try raising it to 3 if you have artifacts or \"halos\" around your cells that have significantly different brightness. Default: 2\n", + "\n", + "* **`weight_decay`** is a regularization parameter used to prevent overfitting. Default: 0.01\n", + "\n", + "* **`validation_frequency`** is the frequency at which the provided evaluation data is used to estimate the model's performance.\n", + "\n", + "* **`intensity_sigma`** is the standard deviation of the feature similarity term. Default: 1\n", + "\n", + "* **`spatial_sigma`** is the standard deviation of the spatial proximity term. Default: 4\n", + "\n", + "* **`ncuts_radius`** is the radius for the NCuts loss computation, in pixels. Default: 2\n", + "\n", + "* **`rec_loss`** is the loss to use for the decoder. Can be Mean Square Error (MSE) or Binary Cross Entropy (BCE). Default : MSE\n", + "\n", + "* **`n_cuts_weight`** is the weight of the NCuts loss in the weighted sum for the backward pass. Default: 0.5\n", + "* **`rec_loss_weight`** is the weight of the reconstruction loss. 
Default: 0.005\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "tTSCC6ChGuuA" + }, + "outputs": [], + "source": [ + "#@markdown ###Path to the training data:\n", + "training_source = \"./gdrive/MyDrive/path/to/data\" #@param {type:\"string\"}\n", + "#@markdown ###Path to save the weights (make sure to have enough space in your drive):\n", + "model_save_path = \"./gdrive/MyDrive/WNET_TRAINING_RESULTS\" #@param {type:\"string\"}\n", + "#@markdown ---\n", + "#@markdown ###Perform validation on a test dataset (optional):\n", + "do_validation = False #@param {type:\"boolean\"}\n", + "#@markdown ###Path to evaluation data (optional, use if checked above):\n", + "eval_source = \"./gdrive/MyDrive/CELLSEG_BENCHMARK/DATA/WNET/eval/vol/\" #@param {type:\"string\"}\n", + "eval_target = \"./gdrive/MyDrive/CELLSEG_BENCHMARK/DATA/WNET/eval/lab/\" #@param {type:\"string\"}\n", + "#@markdown ---\n", + "#@markdown ###Training parameters\n", + "number_of_epochs = 50 #@param {type:\"number\"}\n", + "#@markdown ###Default advanced parameters\n", + "use_default_advanced_parameters = False #@param {type:\"boolean\"}\n", + "#@markdown If not, please change:\n", + "\n", + "#@markdown Training parameters:\n", + "batch_size = 4 #@param {type:\"number\"}\n", + "learning_rate = 2e-5 #@param {type:\"number\"}\n", + "num_classes = 2 #@param {type:\"number\"}\n", + "weight_decay = 0.01 #@param {type:\"number\"}\n", + "#@markdown Validation parameters:\n", + "validation_frequency = 2 #@param {type:\"number\"}\n", + "#@markdown SoftNCuts parameters:\n", + "intensity_sigma = 1.0 #@param {type:\"number\"}\n", + "spatial_sigma = 4.0 #@param {type:\"number\"}\n", + "ncuts_radius = 2 #@param {type:\"number\"}\n", + "#@markdown Reconstruction loss:\n", + "rec_loss = \"MSE\" #@param[\"MSE\", \"BCE\"]\n", + "#@markdown Weighted sum of losses:\n", + "n_cuts_weight = 0.5 #@param {type:\"number\"}\n", + "rec_loss_weight = 0.005 #@param {type:\"number\"}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HtoIo5GcKIXX" + }, + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "arWhMU6aKsri" + }, + "source": [ + "# **4. Train the network**\n", + "---\n", + "\n", + "Important Reminder: Google Colab imposes a maximum session time to prevent extended GPU usage, such as for data mining. Ensure your training duration stays under 12 hours. If your training is projected to exceed this limit, consider reducing the `number_of_epochs`." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "L59J90S_Kva3" + }, + "source": [ + "## **4.1. 
Initialize the config**\n", + "---" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YOgLyUwPjvUX" + }, + "outputs": [], + "source": [ + "# @title\n", + "train_data_folder = Path(training_source)\n", + "results_path = Path(model_save_path)\n", + "results_path.mkdir(exist_ok=True)\n", + "eval_image_folder = Path(eval_source)\n", + "eval_label_folder = Path(eval_target)\n", + "\n", + "eval_dict = c.create_eval_dataset_dict(\n", + " eval_image_folder,\n", + " eval_label_folder,\n", + " ) if do_validation else None\n", + "\n", + "try:\n", + " import wandb\n", + " WANDB_INSTALLED = True\n", + "except ImportError:\n", + " WANDB_INSTALLED = False\n", + "\n", + "\n", + "train_config = WNetTrainingWorkerConfig(\n", + " device=\"cuda:0\",\n", + " max_epochs=number_of_epochs,\n", + " learning_rate=2e-5,\n", + " validation_interval=2,\n", + " batch_size=4,\n", + " num_workers=2,\n", + " weights_info=WeightsInfo(),\n", + " results_path_folder=str(results_path),\n", + " train_data_dict=c.create_dataset_dict_no_labs(train_data_folder),\n", + " eval_volume_dict=eval_dict,\n", + ") if use_default_advanced_parameters else WNetTrainingWorkerConfig(\n", + " device=\"cuda:0\",\n", + " max_epochs=number_of_epochs,\n", + " learning_rate=learning_rate,\n", + " validation_interval=validation_frequency,\n", + " batch_size=batch_size,\n", + " num_workers=2,\n", + " weights_info=WeightsInfo(),\n", + " results_path_folder=str(results_path),\n", + " train_data_dict=c.create_dataset_dict_no_labs(train_data_folder),\n", + " eval_volume_dict=eval_dict,\n", + " # advanced\n", + " num_classes=num_classes,\n", + " weight_decay=weight_decay,\n", + " intensity_sigma=intensity_sigma,\n", + " spatial_sigma=spatial_sigma,\n", + " radius=ncuts_radius,\n", + " reconstruction_loss=rec_loss,\n", + " n_cuts_weight=n_cuts_weight,\n", + " rec_loss_weight=rec_loss_weight,\n", + ")\n", + "wandb_config = WandBConfig(\n", + " mode=\"disabled\" if not WANDB_INSTALLED else \"online\",\n", + " save_model_artifact=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "idowGpeQPIm2" + }, + "source": [ + "## **4.2. 
Start training**\n", + "---" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OXxKZhGMqguz" + }, + "outputs": [], + "source": [ + "# @title\n", + "worker = c.get_colab_worker(worker_config=train_config, wandb_config=wandb_config)\n", + "for epoch_loss in worker.train():\n", + " continue" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Once you have trained the model, you will have the weights as a .pth file" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "include_colab_link": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/notebooks/Colab_inference_demo.ipynb b/notebooks/Colab_inference_demo.ipynb index 3e79717e..e5d3888e 100644 --- a/notebooks/Colab_inference_demo.ipynb +++ b/notebooks/Colab_inference_demo.ipynb @@ -7,7 +7,7 @@ "colab_type": "text" }, "source": [ - "\"Open" + "\"Open" ] }, { @@ -46,29 +46,16 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "id": "bnFKu6uFAm-z", - "collapsed": true, - "outputId": "a52993ed-bfc1-4b44-973c-3f7da876e33a", - "colab": { - "base_uri": "https://localhost:8080/" - } + "collapsed": true }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - } - ], + "outputs": [], "source": [ "#@markdown ##Install CellSeg3D and grab demo data\n", "!git clone https://github.com/AdaptiveMotorControlLab/CellSeg3d.git --branch main --single-branch ./CellSeg3D\n", - "!pip install napari-cellseg3d\n", - "!pip install pyClesperanto" + "!pip install napari-cellseg3d" ] }, { @@ -83,24 +70,11 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { - "id": "vzm75tE_Am-0", - "outputId": "81a95be8-fe48-4a5b-a64c-2f993772c418", - "colab": { - "base_uri": "https://localhost:8080/" - } + "id": "vzm75tE_Am-0" }, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/usr/local/lib/python3.10/dist-packages/pytools/persistent_dict.py:52: RecommendedHashNotFoundWarning: Unable to import recommended hash 'siphash24.siphash13', falling back to 'hashlib.sha256'. Run 'python3 -m pip install siphash24' to install the recommended hash.\n", - " warn(\"Unable to import recommended hash 'siphash24.siphash13', \"\n" - ] - } - ], + "outputs": [], "source": [ "# @title Load libraries\n", "import napari_cellseg3d\n", @@ -145,43 +119,11 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { - "id": "Fe8hNkOpAm-0", - "outputId": "3488c95a-b0d0-4557-d69f-1c89640cfaf3", - "colab": { - "base_uri": "https://localhost:8080/" - } + "id": "Fe8hNkOpAm-0" }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "You have GPU access\n", - "Sun Dec 15 21:09:57 2024 \n", - "+---------------------------------------------------------------------------------------+\n", - "| NVIDIA-SMI 535.104.05 Driver Version: 535.104.05 CUDA Version: 12.2 |\n", - "|-----------------------------------------+----------------------+----------------------+\n", - "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n", - "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n", - "| | | MIG M. 
|\n", - "|=========================================+======================+======================|\n", - "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", - "| N/A 46C P8 9W / 70W | 3MiB / 15360MiB | 0% Default |\n", - "| | | N/A |\n", - "+-----------------------------------------+----------------------+----------------------+\n", - " \n", - "+---------------------------------------------------------------------------------------+\n", - "| Processes: |\n", - "| GPU GI CI PID Type Process name GPU Memory |\n", - "| ID ID Usage |\n", - "|=======================================================================================|\n", - "| No running processes found |\n", - "+---------------------------------------------------------------------------------------+\n" - ] - } - ], + "outputs": [], "source": [ "#@markdown This cell verifies if GPU access is available.\n", "\n", @@ -212,11 +154,11 @@ "execution_count": 4, "metadata": { "id": "O0jLRpARAm-0", - "outputId": "fdf0800b-976a-47ef-d848-b7c45621f2c4", "colab": { "base_uri": "https://localhost:8080/", "height": 35 - } + }, + "outputId": "e4e8549c-7100-4c0c-bc30-505c0dfeb138" }, "outputs": [ { @@ -243,15 +185,66 @@ "cle.select_device(\"cupy\")" ] }, + { + "cell_type": "markdown", + "source": [ + "### Select the pretrained model" + ], + "metadata": { + "id": "b6vIW_oDlpok" + } + }, { "cell_type": "code", + "source": [ + "model_selection = \"SwinUNetR\" #@param [\"SwinUNetR\", \"WNet3D\", \"SegResNet\"]\n", + "print(f\"Selected model: {model_selection}\")" + ], + "metadata": { + "id": "5tkEI1q-loqB", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "d41875da-3879-4158-8a0f-6330afe442af" + }, "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Selected model: SwinUNetR\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "from napari_cellseg3d.config import ModelInfo\n", + "\n", + "model_info = ModelInfo(\n", + " name=model_selection,\n", + " model_input_size=64 if model_selection == \"SegResNet\" else [64,64,64],\n", + " num_classes=2,\n", + ")\n", + "inference_config.model_info = model_info" + ], + "metadata": { + "id": "aPFS4WTdmPo3" + }, + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": 7, "metadata": { "id": "hIEKoyEGAm-0", - "outputId": "c616aab6-a4e7-463b-a051-923bf85b8380", "colab": { "base_uri": "https://localhost:8080/" - } + }, + "outputId": "2103baf6-8875-433b-8799-41e0d1f3c7f0" }, "outputs": [ { @@ -266,7 +259,7 @@ "Window overlap is 0.25\n", "Dataset loaded on cuda device\n", "--------------------\n", - "MODEL DIMS : 64\n", + "MODEL DIMS : [64, 64, 64]\n", "Model name : SwinUNetR\n", "Instantiating model...\n" ] @@ -276,8 +269,7 @@ "name": "stderr", "text": [ "monai.networks.nets.swin_unetr SwinUNETR.__init__:img_size: Argument `img_size` has been deprecated since version 1.3. It will be removed in version 1.5. 
The img_size argument is not required anymore and checks on the input size are run during forward().\n", - "INFO:napari_cellseg3d.utils:********************\n", - "INFO:napari_cellseg3d.utils:Weight file SwinUNetR_latest.pth already exists, skipping download\n" + "INFO:napari_cellseg3d.utils:********************\n" ] }, { @@ -291,6 +283,8 @@ "output_type": "stream", "name": "stderr", "text": [ + "INFO:napari_cellseg3d.utils:Downloading the model from HuggingFace https://huggingface.co/C-Achard/cellseg3d/resolve/main/SwinUNetR_latest.tar.gz....\n", + "270729216B [00:10, 26012663.01B/s] \n", "You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n" ] }, @@ -309,14 +303,14 @@ "Dataset loaded on cuda device\n", "--------------------\n", "Loading layer\n", - "2024-12-15 21:10:06,566 - INFO - Apply pending transforms - lazy: False, pending: 0, upcoming 'QuantileNormalization', transform is not lazy\n", - "2024-12-15 21:10:06,592 - INFO - Apply pending transforms - lazy: False, pending: 0, upcoming 'ToTensor', transform is not lazy\n", - "2024-12-15 21:10:06,595 - INFO - Apply pending transforms - lazy: False, pending: 0, upcoming 'EnsureType', transform is not lazy\n", + "2024-12-22 18:58:42,183 - INFO - Apply pending transforms - lazy: False, pending: 0, upcoming 'QuantileNormalization', transform is not lazy\n", + "2024-12-22 18:58:42,279 - INFO - Apply pending transforms - lazy: False, pending: 0, upcoming 'ToTensor', transform is not lazy\n", + "2024-12-22 18:58:42,290 - INFO - Apply pending transforms - lazy: False, pending: 0, upcoming 'EnsureType', transform is not lazy\n", "Done\n", "----------\n", "Inference started on layer...\n", "Post-processing...\n", - "Layer prediction saved as : volume_SwinUNetR_pred_1_2024_12_15_21_10_09\n" + "Layer prediction saved as : volume_SwinUNetR_pred_1_2024_12_22_18_58_48\n" ] } ], @@ -329,20 +323,20 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": { "id": "IFbmZ3_zAm-1", - "outputId": "b9abbb7e-40a7-407e-eca5-48142a608712", "colab": { "base_uri": "https://localhost:8080/" - } + }, + "outputId": "bde6a6c5-f47f-4164-9e1c-3bf5a94dd00d" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ - "1it [00:00, 11.29it/s]\n", + "1it [00:00, 9.61it/s]\n", "clesperanto's cupy / CUDA backend is experimental. Please use it with care. 
The following functions are known to cause issues in the CUDA backend:\n", "affine_transform, apply_vector_field, create(uint64), create(int32), create(int64), resample, scale, spots_to_pointlist\n", "divide by zero encountered in scalar divide\n", @@ -354,6 +348,10 @@ "source": [ "# @title Post-process the result\n", "# @markdown This cell post-processes the result of the inference : thresholding, instance segmentation, and statistics.\n", + "\n", + "if model_selection == \"WNet3D\":\n", + " result[0].semantic_segmentation = result[0].semantic_segmentation[1]\n", + "\n", "instance_segmentation,stats = cs3d.post_processing(\n", " result[0].semantic_segmentation,\n", " config=post_process_config,\n", @@ -362,23 +360,23 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": { "id": "TMRiQ-m4Am-1", - "outputId": "96604c9f-cc6a-4b02-9c06-ada41acccb40", "colab": { "base_uri": "https://localhost:8080/", "height": 496, "referenced_widgets": [ - "14688e5b41f646449485e9aa4f724724", - "4a4f13871b914dc98610387a36dbac5b", - "d39fe022444544f296e5ada5630ca7e8", - "b0887fa3ae9842c2ac012521f0e4b954", - "15f5f78501fd408bb6477c02adff73fa", - "a8a98fa6693c4271abb49e5dc59f3e99", - "6bafd832de8f433fa7439b505e2fe922" + "7a72ee57e14c440bb2ce281da67e1311", + "2692114df7304a1cbc703e8bd0f848c3", + "697f7288fff64aefbee1a5c0d4894987", + "1e79fde882a44cd984a54a71c3337759", + "9cb58613b1a74eaeb285f6d2d77d567b", + "a1d487697e4b4ea6b897f380c2b112cc", + "10441f745a6f41cf8655b2fafbb8204f" ] - } + }, + "outputId": "2d819126-5478-4d98-a5e2-ecacb7872465" }, "outputs": [ { @@ -390,7 +388,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "14688e5b41f646449485e9aa4f724724" + "model_id": "7a72ee57e14c440bb2ce281da67e1311" } }, "metadata": {} @@ -416,11 +414,11 @@ " \n", "
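The post-processing cell above special-cases WNet3D because it outputs one prediction map per class (`num_classes=2` in this demo); the added lines keep a single channel before thresholding. A standalone sketch of that step (treating channel 1 as the foreground map is inferred from the notebook's indexing, not from a documented guarantee):

```python
import numpy as np

# WNet3D yields a (num_classes, D, H, W) array; keep one channel so the
# thresholding and instance-segmentation steps receive a single semantic map.
semantic = result[0].semantic_segmentation
if model_selection == "WNet3D":
    semantic = semantic[1]  # second channel, used as foreground here
foreground = np.asarray(semantic) > 0.5  # illustrative threshold; post_process_config sets the real one
```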
-        [stripped widget output: HTML repr of update_plot, def update_plot(z), /content/<ipython-input-7-245acde924e0>, no docstring]
+        [stripped widget output: HTML repr of update_plot, def update_plot(z), /content/<ipython-input-9-245acde924e0>, no docstring]
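The widget output above is an interactive z-slice browser driven by an `update_plot(z)` callback; the cell's source falls outside this hunk. A minimal sketch of such a cell, assuming `instance_segmentation` from the post-processing step (the notebook's actual plotting code may differ):

```python
import matplotlib.pyplot as plt
from ipywidgets import interact

def update_plot(z):
    # Display one z-slice of the instance segmentation.
    plt.figure(figsize=(6, 6))
    plt.imshow(instance_segmentation[z], cmap="turbo")
    plt.axis("off")
    plt.show()

# Slider over the z-axis of the volume.
interact(update_plot, z=(0, instance_segmentation.shape[0] - 1))
```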
" ] }, "metadata": {}, - "execution_count": 7 + "execution_count": 9 } ], "source": [ @@ -464,14 +462,14 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": { "id": "Tw5exJ5EAm-1", - "outputId": "5ddb1416-198e-4090-9e14-74ec214e8b68", "colab": { "base_uri": "https://localhost:8080/", "height": 424 - } + }, + "outputId": "3aa36115-0b22-495b-b7e9-3eba7c06069a" }, "outputs": [ { @@ -521,7 +519,7 @@ ], "text/html": [ "\n", - "
\n", + "
\n", "
\n", "