Skip to content

Commit

Permalink
Colab notebooks update (#103)
Browse files Browse the repository at this point in the history
* Fix use of deprecated arg in colab training

* Refactor model save name path + comment wandb cell

* Update Colab_WNet3D_training.ipynb

* Improve logging in Colab

* Subclass WnetTraininWorker to avoid duplication

* Remove strict channel first

* Add missing channel_dim, remove strict_check=False

* Update worker_training.py

* Update worker_training.py

* Disable strict checks for channelfirstd

* Update worker_training.py

* Temp disable channel first

* Fix init of Colab worker

* Move issues with transforms to colab script + disable pad/channelfirst

* Enable ChannelFirst again

* Remove strict_check = False in original worker

Seems to be a Colab-specific issue

* Remove redundant code + Colab notebook tweaks

* Revert wandb check

* Update docs + Colab inference

* Update training_wnet.rst

* Update Colab_WNet3D_training.ipynb

* update / WIP

* Update Colab_inference_demo.ipynb

* Update Colab_inference_demo.ipynb

* Update Colab_inference_demo.ipynb

* Update Colab_inference_demo.ipynb

* Update Colab_inference_demo.ipynb

* Update Colab_inference_demo.ipynb

* Update Colab_inference_demo.ipynb

* Update Colab_inference_demo.ipynb

* nearly final!

* exec

* final

---------

Co-authored-by: Mackenzie Mathis <[email protected]>
Co-authored-by: Mackenzie Mathis <[email protected]>
  • Loading branch information
3 people authored Dec 22, 2024
1 parent bb806f0 commit 8c6c306
Show file tree
Hide file tree
Showing 5 changed files with 689 additions and 1,262 deletions.
21 changes: 9 additions & 12 deletions docs/source/guides/training_wnet.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,21 @@ The WNet3D **does not require a large amount of data to train**, but **choosing

You may find below some guidelines, based on our own data and testing.

The WNet3D is designed to segment objects based on their brightness, and is particularly well-suited for images with a clear contrast between objects and background.

The WNet3D is not suitable for images with artifacts, therefore care should be taken that the images are clean and that the objects are at least somewhat distinguishable from the background.
The WNet3D is a self-supervised learning approach for 3D cell segmentation, and relies on the assumption that structural and morphological features of cells can be inferred directly from unlabeled data. This involves leveraging inherent properties such as spatial coherence and local contrast in imaging volumes to distinguish cellular structures. This approach assumes that meaningful representations of cellular boundaries and nuclei can emerge solely from raw 3D volumes. Thus, we strongly recommend that you use WNet3D on stacks that have clear foreground/background segregation and limited noise. Even if your final samples have noise, it is best to train on data that is as clean as you can.


.. important::
For optimal performance, the following should be avoided for training:

- Images with very large, bright regions
- Almost-empty and empty images
- Images with large empty regions or "holes"
- Images with over-exposed pixels/artifacts you do not want to be learned!
- Almost-empty and/or fully empty images, especially if noise is present (it will learn to segment very small objects!).

However, the model may be accomodate:
However, the model may accomodate:

- Uneven brightness distribution
- Varied object shapes and radius
- Noisy images
- Uneven illumination across the image
- Uneven brightness distribution in your image!
- Varied object shapes and radius!
- Noisy images (as long as resolution is sufficient and boundaries are clear)!
- Uneven illumination across the image!

For optimal results, during inference, images should be similar to those the model was trained on; however this is not a strict requirement.

Expand Down Expand Up @@ -88,7 +85,7 @@ Common issues troubleshooting
If you do not find a satisfactory answer here, please do not hesitate to `open an issue`_ on GitHub.


- **The NCuts loss "explodes" after a few epochs** : Lower the learning rate, for example start with a factor of two, then ten.
- **The NCuts loss "explodes" upward after a few epochs** : Lower the learning rate, for example start with a factor of two, then ten.

- **Reconstruction (decoder) performance is poor** : First, try increasing the weight of the reconstruction loss. If this is ineffective, switch to BCE loss and set the scaling factor of the reconstruction loss to 0.5, OR adjust the weight of the MSE loss.

Expand Down
20 changes: 14 additions & 6 deletions napari_cellseg3d/code_models/worker_training.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Contains the workers used to train the models."""

import platform
import time
from abc import abstractmethod
Expand Down Expand Up @@ -200,7 +201,10 @@ def get_patch_dataset(self, train_transforms):
patch_func = Compose(
[
LoadImaged(keys=["image"], image_only=True),
EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
EnsureChannelFirstd(
keys=["image"],
channel_dim="no_channel",
),
RandSpatialCropSamplesd(
keys=["image"],
roi_size=(
Expand Down Expand Up @@ -235,7 +239,8 @@ def get_dataset_eval(self, eval_dataset_dict):
[
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(
keys=["image", "label"], channel_dim="no_channel"
keys=["image", "label"],
channel_dim="no_channel",
),
# RandSpatialCropSamplesd(
# keys=["image", "label"],
Expand Down Expand Up @@ -280,7 +285,10 @@ def get_dataset(self, train_transforms):
load_single_images = Compose(
[
LoadImaged(keys=["image"]),
EnsureChannelFirstd(keys=["image"]),
EnsureChannelFirstd(
keys=["image"],
channel_dim="no_channel",
),
Orientationd(keys=["image"], axcodes="PLI"),
SpatialPadd(
keys=["image"],
Expand Down Expand Up @@ -1345,9 +1353,9 @@ def get_patch_loader_func(num_samples):
)
sample_loader_eval = get_patch_loader_func(num_val_samples)
else:
num_train_samples = (
num_val_samples
) = self.config.num_samples
num_train_samples = num_val_samples = (
self.config.num_samples
)

sample_loader_train = get_patch_loader_func(
num_train_samples
Expand Down
Loading

0 comments on commit 8c6c306

Please sign in to comment.