From 7a6563d7f083932acca4b6801a17661ea2a89a62 Mon Sep 17 00:00:00 2001
From: aecelaya
Date: Fri, 15 Nov 2024 16:52:51 -0600
Subject: [PATCH 01/10] Bug fix for training subsets of the folds.

---
 mist/main.py | 2 +-
 mist/scripts/run_all_entrypoint.py | 2 +-
 mist/scripts/train_entrypoint.py | 3 ++-
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/mist/main.py b/mist/main.py
index 4fa858b..d4570cc 100755
--- a/mist/main.py
+++ b/mist/main.py
@@ -144,7 +144,7 @@ def main(arguments: argparse.Namespace) -> None:
     # Check if the number of folds is compatible with the number of folds
     # specified.
     if (
-        np.max(mist_arguments.folds) + 1 < mist_arguments.nfolds or
+        np.max(mist_arguments.folds) + 1 > mist_arguments.nfolds or
         len(mist_arguments.folds) > mist_arguments.nfolds
     ):
         raise AssertionError(
diff --git a/mist/scripts/run_all_entrypoint.py b/mist/scripts/run_all_entrypoint.py
index 956f24d..4563b21 100755
--- a/mist/scripts/run_all_entrypoint.py
+++ b/mist/scripts/run_all_entrypoint.py
@@ -44,7 +44,7 @@ def run_all_entry():
     # Check if the number of folds is compatible with the number of folds
     # specified.
     if (
-        np.max(mist_arguments.folds) + 1 < mist_arguments.nfolds or
+        np.max(mist_arguments.folds) + 1 > mist_arguments.nfolds or
         len(mist_arguments.folds) > mist_arguments.nfolds
     ):
         raise AssertionError(
diff --git a/mist/scripts/train_entrypoint.py b/mist/scripts/train_entrypoint.py
index 9ce240e..dc27c8f 100755
--- a/mist/scripts/train_entrypoint.py
+++ b/mist/scripts/train_entrypoint.py
@@ -44,7 +44,7 @@ def train_entry():
     # Check if the number of folds is compatible with the number of folds
     # specified.
     if (
-        np.max(mist_arguments.folds) + 1 < mist_arguments.nfolds or
+        np.max(mist_arguments.folds) + 1 > mist_arguments.nfolds or
         len(mist_arguments.folds) > mist_arguments.nfolds
     ):
         raise AssertionError(
@@ -71,3 +71,4 @@ def train_entry():

 if __name__ == "__main__":
     train_entry()
+

From 2b318519b05fa967f7e70c3f758a31b4ad639c85 Mon Sep 17 00:00:00 2001
From: aecelaya
Date: Fri, 15 Nov 2024 16:53:27 -0600
Subject: [PATCH 02/10] Add exceptions that we'll use later.

---
 mist/runtime/exceptions.py | 21 ++++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/mist/runtime/exceptions.py b/mist/runtime/exceptions.py
index a071841..947b695 100644
--- a/mist/runtime/exceptions.py
+++ b/mist/runtime/exceptions.py
@@ -3,7 +3,6 @@
 class InsufficientValidationSetError(Exception):
     """Raised if validation set size is smaller than the number of GPUs."""
-
     def __init__(self, val_size: int, world_size: int) -> None:
         self.message = (
             f"Validation set size of {val_size} is too small for {world_size} "
@@ -11,3 +10,23 @@ def __init__(self, val_size: int, world_size: int) -> None:
             "number of GPUs."
         )
         super().__init__(self.message)
+
+
+class NaNLossError(Exception):
+    """Raised if the loss is NaN."""
+    def __init__(self, epoch) -> None:
+        self.message = (
+            f"Encountered NaN loss value in epoch {epoch}. Stopping training. "
+            "Consider using a different optimizer, reducing the learning rate, "
+            "or using gradient clipping."
+        )
+        super().__init__(self.message)
+
+
+class NoGPUsAvailableError(Exception):
+    """Raised if no GPU is available."""
+    def __init__(self) -> None:
+        self.message = (
+            "No GPU available. Please check your hardware configuration."
+        )
+        super().__init__(self.message)
\ No newline at end of file

From 6ba261f1440403a3765c315b013525025a4309e0 Mon Sep 17 00:00:00 2001
From: aecelaya
Date: Fri, 15 Nov 2024 16:57:28 -0600
Subject: [PATCH 03/10] Readability update for exceptions.py.

---
 mist/runtime/exceptions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mist/runtime/exceptions.py b/mist/runtime/exceptions.py
index 947b695..9d8e6b2 100644
--- a/mist/runtime/exceptions.py
+++ b/mist/runtime/exceptions.py
@@ -29,4 +29,4 @@ def __init__(self) -> None:
         self.message = (
             "No GPU available. Please check your hardware configuration."
         )
-        super().__init__(self.message)
\ No newline at end of file
+        super().__init__(self.message)

From 9fbfe72093b0e8e575fe9675fff80e8a916e19a2 Mon Sep 17 00:00:00 2001
From: aecelaya
Date: Mon, 2 Dec 2024 22:32:45 -0600
Subject: [PATCH 04/10] Update dependencies and move them to requirements.txt.

---
 pyproject.toml | 40 +++++++++++++++++-----------------------
 requirements.txt | 15 +++++++++++++++
 2 files changed, 32 insertions(+), 23 deletions(-)
 create mode 100644 requirements.txt

diff --git a/pyproject.toml b/pyproject.toml
index 436268e..8338376 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,8 +1,8 @@
 [project]
 name = "mist-medical"
-version = "0.1.5-beta"
+version = "0.1.6-beta"
 requires-python = ">= 3.8"
-description = "MIST is a simple, fully automated framework for 3D medical imaging segmentation."
+description = "MIST is a simple and scalable end-to-end framework for medical imaging segmentation."
 readme = "README.md"
 license = {file = "LICENSE"}
 authors = [
@@ -25,33 +25,17 @@ keywords = [
     "semantic segmentation",
     "medical image analysis",
     "medical image segmentation",
-    "nnU-Net",
-    "nnunet",
+    "nnUNet",
     "U-Net",
-    "unet",
     "vision transformers",
-    "UNETR",
-    "unetr"
-]
-dependencies = [
-    "torch>=2.0.1",
-    "monai>=1.3.0",
-    "antspyx>=0.3.8",
-    "simpleitk>=2.2.1",
-    "numpy",
-    "pandas",
-    "rich",
-    "scipy",
-    "scikit-learn",
-    "scikit-image",
-    "nvidia-dali-cuda110",
-    "tensorboard",
-    "einops",
+    "Swin UNETR"
 ]
+dynamic = ["dependencies"]

 [project.urls]
 homepage = "https://github.com/mist-medical/MIST"
 repository = "https://github.com/mist-medical/MIST"
+documentation = "https://mist-medical.readthedocs.io/"

 [project.scripts]
 mist_run_all = "mist.scripts.run_all_entrypoint:run_all_entry"
@@ -64,4 +48,14 @@
 mist_postprocess = "mist.post_preds:mist_postprocess_entry"
 mist_convert_dataset = "mist.convert_to_mist:convert_to_mist_entry"

 [tool.codespell]
-skip = ".git"
\ No newline at end of file
+skip = ".git"
+
+[build-system]
+requires = ["setuptools", "wheel", "pip"]
+build-backend = "setuptools.build_meta"
+
+[tool.dependencies-dev]
+pytest = "^7.0"
+flake8 = "^6.0"
+black = "^23.0"
+mypy = "^1.0"
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..cc81a52
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,15 @@
+--extra-index-url https://pypi.nvidia.com/
+
+torch
+monai
+antspyx
+simpleitk
+numpy
+pandas
+rich
+scipy
+scikit-learn
+scikit-image
+nvidia-dali-cuda120
+tensorboard
+einops

From 46f762291df5cf75a6d9f5cb71f9b337c39c20c8 Mon Sep 17 00:00:00 2001
From: aecelaya
Date: Mon, 2 Dec 2024 22:33:29 -0600
Subject: [PATCH 05/10] Update dockerfile with latest version of pytorch (2.5.1).
---
 Dockerfile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Dockerfile b/Dockerfile
index 3777f03..feb18de 100755
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,4 +1,4 @@
-FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-runtime
+FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime

 # Set environment variables for non-interactive installation.
 ENV DEBIAN_FRONTEND=noninteractive

From 2e46e3e6af4d17474cb1c986a9b93a7b5cd079c4 Mon Sep 17 00:00:00 2001
From: aecelaya
Date: Mon, 2 Dec 2024 22:34:41 -0600
Subject: [PATCH 06/10] Add deep supervision wrapper to loss functions.

---
 mist/runtime/loss_functions.py | 82 +++++++++++++++++++++++++++++++++-
 1 file changed, 80 insertions(+), 2 deletions(-)

diff --git a/mist/runtime/loss_functions.py b/mist/runtime/loss_functions.py
index 7fe5d40..047ac1e 100755
--- a/mist/runtime/loss_functions.py
+++ b/mist/runtime/loss_functions.py
@@ -1,5 +1,5 @@
 """Loss function implementations for training segmentation models."""
-from typing import Tuple
+from typing import Tuple, Optional, Callable
 import argparse
 import torch
@@ -9,6 +9,84 @@
 from mist.runtime import loss_utils
+
+class DeepSupervisionLoss(nn.Module):
+    """Loss function for deep supervision in segmentation tasks.
+
+    This class calculates the loss for the main output and additional deep
+    supervision heads using a geometric weighting scheme. Deep supervision
+    provides intermediate outputs during training to guide the model's learning
+    at multiple stages.
+
+    Attributes:
+        loss_fn: The base loss function to apply (e.g., Dice loss).
+        scaling_fn: A function to scale the loss for each supervision head.
+            Defaults to geometric scaling by 0.5 ** k, where k is the index.
+    """
+    def __init__(
+        self,
+        loss_fn: nn.Module,
+        scaling_fn: Optional[Callable[[int], float]]=None
+    ):
+        super().__init__()
+        self.loss_fn = loss_fn
+        self.scaling_fn = scaling_fn or (lambda k: 0.5 ** k)
+
+    def apply_loss(self, y_true, y_pred, alpha=None, dtm=None):
+        """Applies the configured loss function with appropriate arguments."""
+        if dtm is not None:
+            return self.loss_fn(y_true, y_pred, dtm, alpha)
+        elif alpha is not None:
+            return self.loss_fn(y_true, y_pred, alpha)
+        return self.loss_fn(y_true, y_pred)
+
+    def forward(
+        self,
+        y_true: torch.Tensor,
+        y_pred: torch.Tensor,
+        y_supervision: Optional[Tuple[torch.Tensor, ...]] = None,
+        alpha: Optional[float] = None,
+        dtm: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Computes the total loss, including contributions from deep supervision.
+
+        Args:
+            y_true: Ground truth mask of shape (batch_size, 1, height, width,
+                depth). This mask is not one-hot encoded because of the way
+                we construct the data loader. The one-hot encoding is applied
+                in the forward pass of the loss function.
+            y_pred: Predicted main output of shape (batch_size, num_classes,
+                height, width, depth). This is the main output of the network.
+                We assume that the predicted mask is the raw output of a network
+                that has not been passed through a softmax function. We apply
+                the softmax function in the forward pass of the loss function.
+            y_supervision (optional): Deep supervision outputs, each of shape
+                (batch_size, num_classes, height, width, depth). Like y_pred,
+                these are raw outputs of the network. We apply the softmax
+                function in the forward pass of the loss function.
+            alpha (optional): Balances region and boundary losses. This is a
+                hyperparameter that should be in the interval [0, 1].
+            dtm (optional): Distance transform maps for boundary-based loss.
+
+        Returns:
+            The total weighted loss.
+        """
+        # Collect main prediction and deep supervision outputs.
+        _y_pred = [y_pred] + (list(y_supervision) if y_supervision else [])
+
+        # Compute weighted loss.
+        losses = torch.stack(
+            [
+                self.scaling_fn(k) * self.apply_loss(y_true, pred, alpha, dtm)
+                for k, pred in enumerate(_y_pred)
+            ]
+        )
+
+        # Normalize using the sum of the scaling factors.
+        normalization = sum(self.scaling_fn(k) for k in range(len(_y_pred)))
+        return losses.sum() / normalization
+
+
 class DiceLoss(nn.Module):
     """Soft Dice loss function for segmentation tasks.
@@ -638,7 +716,7 @@ def get_loss(args: argparse.Namespace) -> nn.Module:
     """
     if args.loss == "dice":
         return DiceLoss(exclude_background=args.exclude_background)
-    if args.loss == "dice_ce":
+    if args.loss == "dice-ce":
         return DiceCELoss(exclude_background=args.exclude_background)
     if args.loss == "bl":
         return BoundaryLoss(exclude_background=args.exclude_background)

From 271109ddc17d2af70baf6edeef6bf6283caa65c1 Mon Sep 17 00:00:00 2001
From: aecelaya
Date: Mon, 2 Dec 2024 22:36:28 -0600
Subject: [PATCH 07/10] Add validate every and after n epochs options to training.

---
 mist/runtime/args.py | 30 ++++----
 mist/runtime/run.py | 174 +++++++++++++++----------------------------
 2 files changed, 76 insertions(+), 128 deletions(-)

diff --git a/mist/runtime/args.py b/mist/runtime/args.py
index 222a40f..f2bffac 100755
--- a/mist/runtime/args.py
+++ b/mist/runtime/args.py
@@ -208,6 +208,17 @@ def get_main_args():
         type=int,
         help="Height, width, and depth of patch size"
     )
+    parser.arg(
+        "--validate-every-n-epochs",
+        type=positive_int,
+        help="Validate every n epochs"
+    )
+    parser.arg(
+        "--validate-after-n-epochs",
+        type=int,
+        default=1,
+        help="Start validation after n epochs. If -1, validate only on the last epoch"
+    )
     parser.arg(
         "--max-patch-size",
         default=[256, 256, 256],
@@ -227,12 +238,6 @@ def get_main_args():
         default=3e-4,
         help="Learning rate"
     )
-    parser.arg(
-        "--exp_decay",
-        type=positive_float,
-        default=0.9999,
-        help="Exponential decay factor"
-    )
     parser.arg(
         "--lr-scheduler",
         type=str,
@@ -241,17 +246,10 @@ def get_main_args():
             "constant",
             "polynomial",
             "cosine",
-            "cosine_warm_restarts",
-            "exponential"
+            "cosine-warm-restarts",
         ],
         help="Learning rate scheduler"
     )
-    parser.arg(
-        "--cosine-first-steps",
-        type=positive_int,
-        default=500,
-        help="Length of a cosine decay cycle in steps, only with cosine_annealing scheduler"
-    )

     # Optimizer parameters.
     parser.arg(
@@ -380,9 +378,9 @@ def get_main_args():
     parser.arg(
         "--loss",
         type=str,
-        default="dice_ce",
+        default="dice-ce",
         choices=[
-            "dice_ce",
+            "dice-ce",
             "dice",
            "bl",
            "hdl",
diff --git a/mist/runtime/run.py b/mist/runtime/run.py
index 5caf1c5..a70a46b 100755
--- a/mist/runtime/run.py
+++ b/mist/runtime/run.py
@@ -359,37 +359,28 @@ def train(self, rank: int, world_size: int) -> None:
             world_size=world_size
         )

-        # Get steps per epoch if not given by user.
-        if self.mist_arguments.steps_per_epoch is None:
-            self.mist_arguments.steps_per_epoch = (
-                len(train_images) // self.mist_arguments.batch_size
+        # Get steps per epoch, number of epochs, and validation parameters
+        # like validate after n epochs and validate every n epochs.
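+        # For example, with 1,000 training images and a batch size of 2, the
+        # helper below gives steps_per_epoch = 1000 // 2 = 500 and
+        # epochs = 250000 // 500 = 500, and the 250-step validation default
+        # is clamped to at most one validation pass per epoch.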
+        epochs_and_validation_params = (
+            utils.get_epochs_and_validation_params(
+                mist_arguments=self.mist_arguments,
+                num_train_examples=len(train_images),
+                num_optimization_steps=constants.TOTAL_OPTIMIZATION_STEPS,
+                validate_every_n_steps=constants.VALIDATE_EVERY_N_STEPS,
             )
-        else:
-            self.mist_arguments.steps_per_epoch = (
-                self.mist_arguments.steps_per_epoch
-            )
-
-        # Get default number of epochs per fold. This is defined as
-        # 250000 / steps_per_epoch.
-        if self.mist_arguments.epochs is None:
-            self.mist_arguments.epochs = (
-                constants.TOTAL_OPTIMIZATION_STEPS //
-                self.mist_arguments.steps_per_epoch
-            )
-
-        # Get number of epochs between each validation. We validate every
-        # 250 steps by default.
+        )
+        steps_per_epoch = epochs_and_validation_params["steps_per_epoch"]
+        epochs = epochs_and_validation_params["epochs"]
         validate_every_n_epochs = (
-            constants.VALIDATE_EVERY_N_STEPS //
-            self.mist_arguments.steps_per_epoch
+            epochs_and_validation_params["validate_every_n_epochs"]
+        )
+        validate_after_n_epochs = (
+            epochs_and_validation_params["validate_after_n_epochs"]
         )
-
-        # Ensure that we validate at most once per epoch.
-        validate_every_n_epochs = max(1, validate_every_n_epochs)

         # Initialize boundary loss weighting schedule.
         boundary_loss_weighting_schedule = utils.AlphaSchedule(
-            n_epochs=self.mist_arguments.epochs,
+            n_epochs=epochs,
             schedule=self.mist_arguments.boundary_loss_schedule,
             constant=self.mist_arguments.loss_schedule_constant,
             init_pause=self.mist_arguments.linear_schedule_pause,
@@ -398,6 +389,9 @@

         # Get loss function based on user arguments.
         loss_fn = loss_functions.get_loss(self.mist_arguments)
+        loss_fn_with_deep_supervision = (
+            loss_functions.DeepSupervisionLoss(loss_fn)
+        )

         # Make sure we are using/have DTMs for boundary-based loss
         # functions.
@@ -456,7 +450,7 @@
             # Get optimizer and lr scheduler
             optimizer = utils.get_optimizer(self.mist_arguments, model)
             learning_rate_scheduler = utils.get_lr_schedule(
-                self.mist_arguments, optimizer
+                self.mist_arguments, optimizer, epochs
             )

             # Float16 inputs during the forward pass produce float16 gradients
@@ -482,8 +476,8 @@
                 writer = SummaryWriter(
                     os.path.join(
                         self.mist_arguments.results, "logs", f"fold_{fold}"
-                )
                     )
+                )

                 # Path and name for best model for this fold.
                 best_model_name = os.path.join(
@@ -520,72 +514,20 @@
                 def compute_loss() -> torch.Tensor:
                     # Make predictions for the batch.
                     output = model(image)  # pylint: disable=cell-var-from-loop

-                    # Compute loss for the batch. The inputs to the loss
-                    # function depend on the loss function being used.
-                    if self.mist_arguments.use_dtms:
-                        # Use distance transform maps for boundary-based loss
-                        # functions.
-                        loss = loss_fn(label, output["prediction"], dtm, alpha)  # pylint: disable=cell-var-from-loop
-                    elif self.mist_arguments.loss in ["cldice"]:
-                        # Use the alpha parameter to weight the cldice and
-                        # dice with cross entropy loss functions.
-                        loss = loss_fn(label, output["prediction"], alpha)  # pylint: disable=cell-var-from-loop
-                    else:
-                        # Use only the image and label for other loss functions
-                        # like dice with cross entropy.
-                        loss = loss_fn(label, output["prediction"])  # pylint: disable=cell-var-from-loop
-
-                    # If deep supervision is enabled, compute the additional
-                    # losses from the deep supervision heads.
Deep supervision
-                    # provides additional output layers that guide the model
-                    # during training by supplying intermediate supervision
-                    # signals at various stages of the model.
-
-                    # We scale the loss from each deep supervision head by a
-                    # factor of (0.5 ** (k + 1)), where k is the index of the
-                    # deep supervision head. This creates a geometric series
-                    # that gives decreasing weight to deeper (later) supervision
-                    # heads. The idea is to ensure that the loss from earlier
-                    # heads (closer to the final output) contributes more to the
-                    # total loss, while still incorporating the information from
-                    # later heads.
-
-                    # After summing the losses from all deep supervision heads,
-                    # we normalize the total loss using a correction factor
-                    # (c_norm). This factor is derived from the sum of the
-                    # geometric series (1 / (2 - 2 ** -n)), where n is the
-                    # number of deep supervision heads. The normalization
-                    # ensures that the total loss isn't biased or dominated by
-                    # the deep supervision losses by making the loss a
-                    # convex combination of the losses from all heads, including
-                    # the main loss.
-                    if self.mist_arguments.deep_supervision:
-                        for k, p in enumerate(output["deep_supervision"]):
-                            # Apply the loss function based on the model's
-                            # configuration. If distance transform maps
-                            # are used, pass them to the loss function.
-                            if self.mist_arguments.use_dtms:
-                                loss += 0.5 ** (k + 1) * loss_fn(  # pylint: disable=cell-var-from-loop
-                                    label, p, dtm, alpha
-                                )
-                            # If cldice loss is used, pass alpha to the loss
-                            # function.
-                            elif self.mist_arguments.loss in ["cldice"]:
-                                loss += 0.5 ** (k + 1) * loss_fn(  # pylint: disable=cell-var-from-loop
-                                    label, p, alpha
-                                )
-                            # Otherwise, compute the loss normally.
-                            else:
-                                loss += 0.5 ** (k + 1) * loss_fn(label, p)  # pylint: disable=cell-var-from-loop
-
-                        # Normalize the total loss from deep supervision heads
-                        # using a correction factor to prevent it from
-                        # dominating the main loss.
-                        c_norm = 1 / (2 - 2 ** -(
-                            len(output["deep_supervision"])
-                            )
-                        )
-                        loss *= c_norm
+                    # Compute loss based on the output and ground truth label.
+                    # Apply deep supervision if enabled.
+                    y_supervision = (
+                        output["deep_supervision"]
+                        if self.mist_arguments.deep_supervision
+                        else None
+                    )
+                    loss = loss_fn_with_deep_supervision(  # pylint: disable=cell-var-from-loop
+                        y_true=label,
+                        y_pred=output["prediction"],
+                        y_supervision=y_supervision,
+                        alpha=alpha,
+                        dtm=dtm,
+                    )

                     # Check if Variational Autoencoder (VAE) regularization
                     # is enabled. VAE regularization encourages the model to
@@ -748,7 +690,7 @@
                     return self.fixed_loss_functions["validation"](label, pred)

                 # Train the model for the specified number of epochs.
-                for epoch in range(self.mist_arguments.epochs):
+                for epoch in range(epochs):
                     # Make sure gradient tracking is on, and do a pass over the
                     # training data.
                     model.train(True)
@@ -762,10 +704,10 @@
                         with progress_bar.TrainProgressBar(
                             epoch + 1,
                             fold,
-                            self.mist_arguments.epochs,
-                            self.mist_arguments.steps_per_epoch
+                            epochs,
+                            steps_per_epoch,
                         ) as pb:
-                            for _ in range(self.mist_arguments.steps_per_epoch):
+                            for _ in range(steps_per_epoch):
                                 # Get data from training loader.
                                 data = train_loader.next()[0]
@@ -791,9 +733,6 @@
                                 else:
                                     loss = train_step(image, label, None, None)

-                                # Update update the learning rate scheduler.
-                                learning_rate_scheduler.step()
-
                                 # Send all training losses to device 0 to add them.
                                 dist.reduce(loss, dst=0)
@@ -803,12 +742,16 @@
                                 # Update the running loss for the progress bar.
                                running_loss = running_loss_train(current_loss)

-                                # Update the progress bar with the running loss.
-                                pb.update(loss=running_loss)
+                                # Update the progress bar with the running loss and
+                                # learning rate.
+                                pb.update(
+                                    loss=running_loss,
+                                    lr=optimizer.param_groups[0]["lr"]
+                                )
                     else:
                         # For all other processes, do not display the progress bar.
                         # Repeat the training steps shown above for the other GPUs.
-                        for _ in range(self.mist_arguments.steps_per_epoch):
+                        for _ in range(steps_per_epoch):
                             # Get data from training loader.
                             data = train_loader.next()[0]
@@ -830,22 +773,30 @@
                             # Send the loss on the current GPU to device 0.
                             dist.reduce(loss, dst=0)

+                    # Update the learning rate scheduler.
+                    learning_rate_scheduler.step()
+
                     # Wait for all processes to finish the epoch.
                     dist.barrier()

                     # Start validation. We don't need gradients on to do reporting.
-                    # Only validate every validate_every_n_epochs epochs.
-                    if (
-                        epoch % validate_every_n_epochs == 0 or
-                        epoch == self.mist_arguments.epochs - 1
-                    ):
+                    # Only validate on the first and last epochs or periodically
+                    # after validate_after_n_epochs.
+                    validate = (
+                        epoch == 0 or epoch == epochs - 1 or
+                        (
+                            epoch >= validate_after_n_epochs and
+                            epoch % validate_every_n_epochs == 0
+                        )
+                    )
+                    if validate:
                         model.eval()
                         with torch.no_grad():
                             # Only log metrics on first process (i.e., rank 0).
                             if rank == 0:
                                 with progress_bar.ValidationProgressBar(
                                     val_steps
-                                ) as pb:
+                                ) as val_pb:
                                     for _ in range(val_steps):
                                         # Get data from validation loader.
                                         data = validation_loader.next()[0]
@@ -860,8 +811,7 @@
                                         # Average the loss across all GPUs.
                                         current_val_loss = (
-                                            val_loss.item() /
-                                            world_size
+                                            val_loss.item() / world_size
                                         )

                                         # Update the running loss for the progress
@@ -872,7 +822,7 @@
                                         # Update the progress bar with the running
                                         # loss.
-                                        pb.update(loss=running_val_loss)
+                                        val_pb.update(loss=running_val_loss)

                                 # Check if validation loss is lower than the current
                                 # best validation loss. If so, save the model.

From e235c2f68a80d6b2fe23d4a074e3e8555cc10dcb Mon Sep 17 00:00:00 2001
From: aecelaya
Date: Mon, 2 Dec 2024 22:37:45 -0600
Subject: [PATCH 08/10] Add get_epochs_and_validation_params to utils for improved
 readability in run.py. Minor improvements to optimizers for AMP.

---
 mist/runtime/utils.py | 104 +++++++++++++++++++++++++++++++++++++-----
 1 file changed, 93 insertions(+), 11 deletions(-)

diff --git a/mist/runtime/utils.py b/mist/runtime/utils.py
index 0fb05bb..c2e9b22 100755
--- a/mist/runtime/utils.py
+++ b/mist/runtime/utils.py
@@ -140,6 +140,7 @@ def set_warning_levels() -> None:
     warnings.simplefilter(action="ignore", category=FutureWarning)
     warnings.simplefilter(action="ignore", category=RuntimeWarning)
     warnings.simplefilter(action="ignore", category=UserWarning)
+    warnings.simplefilter(action="ignore", category=DeprecationWarning)


 def create_empty_dir(path: str) -> None:
@@ -413,13 +414,15 @@ def convert_dict_to_df(patients: Dict[str, Dict[str, str]]) -> pd.DataFrame:

 def get_lr_schedule(
     mist_arguments: argparse.Namespace,
-    optimizer: torch.optim.Optimizer # type: ignore
+    optimizer: torch.optim.Optimizer, # type: ignore
+    epochs: int,
 ) -> torch.optim.lr_scheduler.LRScheduler:
     """Get learning rate schedule based on user input.

     Args:
         mist_arguments: Command line arguments.
         optimizer: Optimizer for which the learning rate schedule is created.
+        epochs: Number of epochs for training.

     Returns:
         lr_scheduler: Learning rate scheduler.
@@ -434,23 +437,19 @@ def get_lr_schedule(
     if mist_arguments.lr_scheduler == "polynomial":
         return torch.optim.lr_scheduler.PolynomialLR(
             optimizer,
-            total_iters=mist_arguments.steps_per_epoch * mist_arguments.epochs,
+            total_iters=epochs,
             power=0.9
         )
-    if mist_arguments.lr_scheduler == "cosine_warm_restarts":
+    if mist_arguments.lr_scheduler == "cosine-warm-restarts":
         return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
             optimizer,
-            T_0=mist_arguments.cosine_first_steps,
+            T_0=int(np.ceil(0.1 * epochs)),
             T_mult=2
         )
     if mist_arguments.lr_scheduler == "cosine":
         return torch.optim.lr_scheduler.CosineAnnealingLR(
             optimizer,
-            T_max=mist_arguments.steps_per_epoch * mist_arguments.epochs
-        )
-    if mist_arguments.lr_scheduler == "exponential":
-        return torch.optim.lr_scheduler.ExponentialLR(
-            optimizer, gamma=mist_arguments.exp_decay
+            T_max=epochs
         )
     raise ValueError(
         "Received invalid learning rate scheduler type "
@@ -476,6 +475,8 @@ def get_optimizer(
        optimizer, add a new if statement and add the corresponding optimizer
        name to runtime/args.py.
     """
+    # Increase epsilon for AMP to avoid NaNs.
+    eps = 1e-4 if mist_arguments.amp else 1e-8
     if mist_arguments.optimizer == "sgd":
         return torch.optim.SGD( # type: ignore
             params=model.parameters(),
@@ -484,11 +485,15 @@
         )
     if mist_arguments.optimizer == "adam":
         return torch.optim.Adam( # type: ignore
-            params=model.parameters(), lr=mist_arguments.learning_rate
+            params=model.parameters(),
+            lr=mist_arguments.learning_rate,
+            eps=eps,
         )
     if mist_arguments.optimizer == "adamw":
         return torch.optim.AdamW( # type: ignore
-            params=model.parameters(), lr=mist_arguments.learning_rate
+            params=model.parameters(),
+            lr=mist_arguments.learning_rate,
+            eps=eps,
         )
     raise ValueError(
         f"Received invalid optimizer type {mist_arguments.optimizer}."
@@ -1309,3 +1314,80 @@ def __call__(self, epoch: int) -> float:
         if self.schedule == "cosine":
             return float(self.cosine(epoch))
         raise ValueError(f"Received invalid schedule type {self.schedule}.")
+
+
+def get_epochs_and_validation_params(
+    mist_arguments: argparse.Namespace,
+    num_train_examples: int,
+    num_optimization_steps: int,
+    validate_every_n_steps: int,
+) -> Dict[str, int]:
+    """Get number of epochs and validation parameters based on user input.
+
+    Args:
+        mist_arguments: Command line arguments.
+        num_train_examples: Number of training examples.
+        num_optimization_steps: Number of optimization steps.
+        validate_every_n_steps: Number of steps between each validation.
+
+    Returns:
+        Dictionary containing the following key-value pairs:
+            steps_per_epoch: Number of steps per epoch.
+            epochs: Number of epochs.
+            validate_every_n_epochs: Number of epochs between each validation.
+            validate_after_n_epochs: Number of epochs before starting
+                validation.
+
+    Raises:
+        ValueError: If validate_after_n_epochs is greater than epochs or a
+            negative value other than -1.
+    """
+    # Get steps per epoch if not given by user.
+    if mist_arguments.steps_per_epoch is None:
+        steps_per_epoch = num_train_examples // mist_arguments.batch_size
+    else:
+        steps_per_epoch = mist_arguments.steps_per_epoch
+
+    # Get default number of epochs per fold. This is defined as
+    # 250000 / steps_per_epoch.
+    if mist_arguments.epochs is None:
+        epochs = num_optimization_steps // steps_per_epoch
+    else:
+        epochs = mist_arguments.epochs
+
+    # Get number of epochs between each validation. We validate every
+    # 250 steps by default.
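+    # For example, with 500 steps per epoch, the default below works out to
+    # max(1, 250 // 500) = 1, i.e. at most one validation pass per epoch.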
+    if mist_arguments.validate_every_n_epochs is None:
+        validate_every_n_epochs = validate_every_n_steps // steps_per_epoch
+
+        # Ensure that we validate at most once per epoch.
+        validate_every_n_epochs = max(1, validate_every_n_epochs)
+    else:
+        validate_every_n_epochs = mist_arguments.validate_every_n_epochs
+
+    # Get number of epochs before starting validation. By default, we start
+    # validation after the first epoch. Otherwise, we start validation after
+    # some user-specified number of epochs.
+    validate_after_n_epochs = mist_arguments.validate_after_n_epochs
+    if validate_after_n_epochs > epochs:
+        raise ValueError(
+            "validate_after_n_epochs must be less than or equal to epochs. Got "
+            f"validate_after_n_epochs = {validate_after_n_epochs} and epochs = "
+            f"{epochs}."
+        )
+
+    if validate_after_n_epochs < 0:
+        if validate_after_n_epochs != -1:
+            raise ValueError(
+                "The only valid negative value for validate_after_n_epochs is "
+                f"-1. Got {validate_after_n_epochs}."
+            )
+        validate_after_n_epochs = epochs
+
+    # Format output as dictionary.
+    return {
+        "steps_per_epoch": steps_per_epoch,
+        "epochs": epochs,
+        "validate_every_n_epochs": validate_every_n_epochs,
+        "validate_after_n_epochs": validate_after_n_epochs,
+    }

From d48d89e622e4425b2060c3b97a2d325932bf51d7 Mon Sep 17 00:00:00 2001
From: aecelaya
Date: Mon, 2 Dec 2024 22:38:18 -0600
Subject: [PATCH 09/10] Readability updates for loss_utils.py.

---
 mist/runtime/loss_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mist/runtime/loss_utils.py b/mist/runtime/loss_utils.py
index ea63d1a..8eb9d77 100644
--- a/mist/runtime/loss_utils.py
+++ b/mist/runtime/loss_utils.py
@@ -127,4 +127,3 @@ def forward(self, img: torch.Tensor) -> torch.Tensor:
             The soft skeleton of the input tensor.
         """
         return self.soft_skeletonize(img)
-

From 273094790238ddf4577c1ad870b8460dfeb47f3c Mon Sep 17 00:00:00 2001
From: aecelaya
Date: Mon, 2 Dec 2024 22:38:53 -0600
Subject: [PATCH 10/10] Add learning rate to training progress bar.

---
 mist/runtime/progress_bar.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/mist/runtime/progress_bar.py b/mist/runtime/progress_bar.py
index db40453..b4b5069 100755
--- a/mist/runtime/progress_bar.py
+++ b/mist/runtime/progress_bar.py
@@ -1,3 +1,4 @@
+"""Progress bars for MIST training and validation loops."""
 from rich.progress import (
     BarColumn,
     TextColumn,
@@ -5,7 +6,7 @@
     MofNCompleteColumn,
     TimeElapsedColumn
 )
-
+import numpy as np

 class TrainProgressBar(Progress):
     def __init__(self, current_epoch, fold, epochs, train_steps):
@@ -20,16 +21,26 @@ def __init__(self, current_epoch, fold, epochs, train_steps):
             TimeElapsedColumn(),
             TextColumn("•"),
             TextColumn("{task.fields[loss]}"),
+            TextColumn("•"),
+            TextColumn("{task.fields[lr]}"), # Learning rate column
         )

+        # Initialize tasks with loss and learning rate fields
         self.task = self.progress.add_task(
-            description="Training",
+            description="Training (loss)",
             total=train_steps,
-            loss=f"loss: "
+            loss=f"loss: ",
+            lr=f"lr: ",
         )

-    def update(self, loss):
-        self.progress.update(self.task, advance=1, loss=f"loss: {loss:.4f}")
+    def update(self, loss, lr):
+        # Update loss task.
+        self.progress.update(
+            self.task,
+            advance=1,
+            loss=f"loss: {loss:.4f}",
+            lr=f"lr: {np.format_float_scientific(lr, precision=3)}",
+        )

     def __enter__(self):
         self.progress.start()
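
A minimal usage sketch of the DeepSupervisionLoss wrapper added in PATCH 06. The
stand-in MSE loss and the tensor shapes are illustrative assumptions only; in
MIST training the wrapped loss comes from loss_functions.get_loss and receives
the label mask, raw logits, and optional alpha/dtm arguments described above.

    import torch
    import torch.nn as nn
    from mist.runtime.loss_functions import DeepSupervisionLoss

    # The wrapped loss is called as loss_fn(y_true, y_pred) when no alpha or
    # dtm is supplied.
    ds_loss = DeepSupervisionLoss(nn.MSELoss())

    y_true = torch.rand(2, 3, 8, 8, 8)
    y_pred = torch.rand(2, 3, 8, 8, 8)   # main network output
    heads = (
        torch.rand(2, 3, 8, 8, 8),       # deep supervision head 1
        torch.rand(2, 3, 8, 8, 8),       # deep supervision head 2
    )

    # Per-head weights are 0.5 ** k for k = 0, 1, 2, and the weighted sum is
    # divided by 1 + 0.5 + 0.25 = 1.75, keeping the result a convex
    # combination of the main and auxiliary losses.
    total = ds_loss(y_true=y_true, y_pred=y_pred, y_supervision=heads)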