Skip to content

Commit

Permalink
Update crf.py
Browse files Browse the repository at this point in the history
  • Loading branch information
MMathisLab authored Dec 23, 2024
1 parent 6d8adbb commit 1fafa5f
Showing 1 changed file with 12 additions and 39 deletions.
51 changes: 12 additions & 39 deletions napari_cellseg3d/code_models/crf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Implements the CRF post-processing step for the WNet3D.
"""Implements the CRF post-processing step for WNet3D.
The CRF requires the following parameters:
Expand All @@ -16,7 +16,8 @@
Philipp Krähenbühl and Vladlen Koltun
NIPS 2011
Implemented using the pydense library available at https://github.com/lucasb-eyer/pydensecrf.
Implemented using the pydensecrf library available at https://github.com/lucasb-eyer/pydensecrf.
However, this is not maintained, thus we maintain this pacakge at https://github.com/AdaptiveMotorControlLab/pydensecrf.
"""

import importlib
Expand All @@ -37,33 +38,11 @@
unary_from_softmax,
)

__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard"
__credits__ = [
"Yves Paychère",
"Colin Hofmann",
"Cyril Achard",
"Philipp Krähenbühl",
"Vladlen Koltun",
"Liang-Chieh Chen",
"George Papandreou",
"Iasonas Kokkinos",
"Kevin Murphy",
"Alan L. Yuille",
"Xide Xia",
"Brian Kulis",
"Lucas Beyer",
]


def correct_shape_for_crf(image, desired_dims=4):
"""Corrects the shape of the image to be compatible with the CRF post-processing step."""
logger.debug(f"Correcting shape for CRF, desired_dims={desired_dims}")
logger.debug(f"Image shape: {image.shape}")
if len(image.shape) > desired_dims:
# if image.shape[0] > 1:
# raise ValueError(
# f"Image shape {image.shape} might have several channels"
# )
image = np.squeeze(image, axis=0)
elif len(image.shape) < desired_dims:
image = np.expand_dims(image, axis=0)
Expand All @@ -72,7 +51,7 @@ def correct_shape_for_crf(image, desired_dims=4):


def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5):
"""CRF post-processing step for the W-Net, applied to a batch of images.
"""CRF post-processing step for the WNet3D, applied to a batch of images.
Args:
images (np.ndarray): Array of shape (N, C, H, W, D) containing the input images.
Expand Down Expand Up @@ -100,7 +79,7 @@ def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5):


def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5):
"""Implements the CRF post-processing step for the W-Net.
"""Implements the CRF post-processing step for the WNet3D.
Inspired by https://arxiv.org/abs/1210.5644, https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506.
Implemented using the pydensecrf library.
Expand All @@ -120,18 +99,14 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5):
"""
if not CRF_INSTALLED:
logger.info(
"pydensecrf not installed, CRF post-processing will not be available. "
"Please install by running : pip install pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master"
"This is not a hard requirement, you do not need it to install it unless you want to use the CRF post-processing step."
"pydensecrf not installed, therefore CRF post-processing will not be available! Please install the package. "
"Please install by running: pip install pydensecrf2 "
)
return None

d = dcrf.DenseCRF(
image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0]
)
# print(f"Image shape : {image.shape}")
# print(f"Prob shape : {prob.shape}")
# d = dcrf.DenseCRF(262144, 3) # npoints, nlabels

# Get unary potentials from softmax probabilities
U = unary_from_softmax(prob)
Expand Down Expand Up @@ -165,7 +140,7 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5):


def crf_with_config(image, prob, config: CRFConfig = None, log=logger.info):
"""Implements the CRF post-processing step for the W-Net.
"""Implements the CRF post-processing step for the WNet3D.
Args:
image (np.ndarray): Array of shape (C, H, W, D) containing the input image.
Expand Down Expand Up @@ -202,7 +177,7 @@ def crf_with_config(image, prob, config: CRFConfig = None, log=logger.info):


class CRFWorker(GeneratorWorker):
"""Worker for the CRF post-processing step for the W-Net."""
"""Worker for the CRF post-processing step for the WNet3D."""

def __init__(
self,
Expand Down Expand Up @@ -230,14 +205,12 @@ def __init__(
self.log = log

def _run_crf_job(self):
"""Runs the CRF post-processing step for the W-Net."""
"""Runs the CRF post-processing step for the WNet3D."""
if not CRF_INSTALLED:
logger.info(
"pydensecrf not installed, CRF post-processing will not be available. "
"Please install by running : pip install pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master"
"This is not a hard requirement, you do not need it to install it unless you want to use the CRF post-processing step."
"pydensecrf not installed, therefore CRF post-processing will not be available! Please install the package. "
"Please install by running: pip install pydensecrf2 "
)
# raise ImportError("pydensecrf is not installed.")

if len(self.images) != len(self.labels):
raise ValueError("Number of images and labels must be the same.")
Expand Down

0 comments on commit 1fafa5f

Please sign in to comment.