Skip to content

Commit

Permalink
Merge pull request #46 from Noam-Dori/import_spots_from_labels
Browse files Browse the repository at this point in the history
Implemented label image to spot ellipsoid plugin
  • Loading branch information
stefanhahmann authored Oct 9, 2023
2 parents 7d354f7 + 7d31978 commit 96e7695
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,31 @@

import bdv.viewer.Source;
import mpicbg.spim.data.sequence.TimePoint;
import mpicbg.spim.data.sequence.VoxelDimensions;
import net.imglib2.Cursor;
import net.imglib2.IterableInterval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.util.Cast;
import net.imglib2.util.LinAlgHelpers;
import net.imglib2.util.Pair;
import net.imglib2.util.ValuePair;
import net.imglib2.view.Views;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.mastodon.mamut.MamutAppModel;
import org.mastodon.mamut.model.Model;
import org.mastodon.mamut.model.ModelGraph;
import org.mastodon.mamut.model.Spot;
import org.scijava.Context;
import org.scijava.app.StatusService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.invoke.MethodHandles;
import java.math.BigInteger;
import java.util.List;

import static java.math.BigInteger.valueOf;

public class ImportSpotFromLabelsController
{

Expand All @@ -30,111 +36,163 @@ public class ImportSpotFromLabelsController

private final List< TimePoint > timePoints;

private final Source< RealType< ? > > source;
private final Source< ? extends RealType< ? > > source;

private final StatusService statusService;

public ImportSpotFromLabelsController( final MamutAppModel appModel, final Context context, int labelChannelIndex )
private final VoxelDimensions voxelDimensions;
private final double sigma;

public ImportSpotFromLabelsController( final MamutAppModel appModel, final Context context, int labelChannelIndex, double sigma )
{
// NB: Use the dimensions of the first source and the first time point only without checking if they are equal in other sources and time points.
this( appModel.getModel(),
appModel.getSharedBdvData().getSpimData().getSequenceDescription().getTimePoints().getTimePointsOrdered(),
Cast.unchecked( appModel.getSharedBdvData().getSources().get( labelChannelIndex ).getSpimSource() ), context
);
Cast.unchecked( appModel.getSharedBdvData().getSources().get( labelChannelIndex ).getSpimSource() ), context,
appModel.getSharedBdvData().getSpimData().getSequenceDescription().getViewSetups().get( 0 ).getVoxelSize(), sigma);
}

protected ImportSpotFromLabelsController(
final Model model, final List< TimePoint > timePoints, final Source< RealType< ? > > source, final Context context
)
final Model model, final List< TimePoint > timePoints, final Source< ? extends RealType< ? > > source, final Context context,
VoxelDimensions voxelDimensions, double sigma)
{
this.modelGraph = model.getGraph();
this.timePoints = timePoints;
this.source = source;
this.statusService = context.service( StatusService.class );
this.statusService = context.getService( StatusService.class );
this.voxelDimensions = voxelDimensions;
this.sigma = sigma;
}

public void createSpotsFromLabels()
{
Spot spot = modelGraph.addVertex();
int timepointId = 0;
double[] center = new double[] { 50, 50, 50 };
double[][] cov = new double[][] { { 400, 20, -10 }, { 20, 200, 30 }, { -10, 30, 100 } };
spot.init( timepointId, center, cov );
int numTimepoints = timePoints.size();

for ( TimePoint frame : timePoints )
{
int frameId = frame.getId();
long[] dimensions = source.getSource( frameId, 0 ).dimensionsAsLongArray();
final RandomAccessibleInterval< RealType< ? > > img = source.getSource( frameId, 0 );
final RandomAccessibleInterval< IntegerType< ? > > img = Cast.unchecked(source.getSource( frameId, 0 ));
for ( int d = 0; d < dimensions.length; d++ )
logger.debug( "Dimension {}, : {}", d, dimensions[ d ] );
IterableInterval< RealType< ? > > iterable = Views.iterable( img );
double[] mean = computeMean( iterable, 42 );
double[][] coviarance = computeCovariance( iterable, mean, 42 );
statusService.showProgress( frameId, numTimepoints );

createSpotsFromLabelImage(img, frameId);
if (statusService != null) {
statusService.showProgress(frameId + 1, numTimepoints);
}

}
}

private void createSpotsFromLabelImage(@NotNull RandomAccessibleInterval<IntegerType<?>> img, int timepointId) {
logger.debug("Computing mean, covariance of all labels at time-point t={}", timepointId);

// get the maximum value possible to learn how many objects need to be instantiated
// this is fine because we expect maximum occupancy here.
// we also subtract the background to truly get the number of elements.
Pair<Integer, Integer> minAndMax = getPixelValueInterval(img);

int numLabels = minAndMax.getB() - minAndMax.getA();
int[] count = new int[numLabels]; // counts the number of pixels in each label, for normalization
long[][] sum = new long[numLabels][3]; // sums up the positions of the label pixels, used for the 1D means
BigInteger[][][] mixedSum = new BigInteger[numLabels][3][3]; // sums up the estimates of mixed coordinates (like xy). Used for covariances.

readImageSumPositions(img, count, sum, mixedSum, minAndMax.getA());

createSpotsFromSums(timepointId, numLabels, count, sum, mixedSum);
}

/**
* Computes the mean position of the pixels whose value equals {@code labelValue}.
* Read the image and get its maximum and minimum values
* @param img an image to read and process
* @return A pair of values (min, max) that represent the minimum and maximum pixel values in the image
* @author Noam Dori
*/
private static double[] computeMean( IterableInterval< RealType< ? > > iterable, int labelValue )
{
logger.debug( "Computing mean of label, {} ", labelValue );
Cursor< RealType< ? > > cursor = iterable.cursor();
double[] sum = new double[ 3 ];
double[] position = new double[ 3 ];
long counter = 0;
while ( cursor.hasNext() )
@Contract("_ -> new")
private static @NotNull Pair<Integer, Integer> getPixelValueInterval(RandomAccessibleInterval<IntegerType<?>> img) {
// read the picture to sum everything up
int min = Integer.MAX_VALUE;
int max = Integer.MIN_VALUE;
Cursor<IntegerType<?>> cursor = Views.iterable(img).cursor();
while (cursor.hasNext())
{
counter++;
int pixelValue = ( int ) cursor.next().getRealDouble();
if ( pixelValue == labelValue )
{
cursor.localize( position );
LinAlgHelpers.add( sum, position, sum );
counter++;
int val = cursor.next().getInteger(); // we ignore 0 as it is BG
if (min > val) {
min = val;
}
if (max < val) {
max = val;
}
}
logger.debug( "Computed mean of label {}. Searched {} pixels.", labelValue, counter );
LinAlgHelpers.scale( sum, 1. / counter, sum );
return sum;
return new ValuePair<>(min, max);
}

/**
* Computes the covariance matrix of the pixels whose value equals {@code labelValue}.
* Use the gathered information to generate all the spots for the given timepoint.
* @param timepointId the timepoint of the image the spots should belong to.
* @param numLabels the maximum value encountered in the image. Also equal to the number of labels.
* @param count the 0D sums (counts). Dimensions: [labelIdx].
* @param sum the 1D sums, i.e S[X]. Dimensions: [labelIdx, coord]
* @param mixedSum the 2D sums, i.e S[XY]. Dimensions: [labelIdx, coord, coord]
* @implNote The covariance formula used here is not the definition COV(X,Y) = E[(X - E[X])(Y - E[Y])]
* but instead its simplification COV(X, Y) = E[XY] - E[X]E[Y].
* Read more <a href=https://en.wikipedia.org/wiki/Covariance#Definition>here</a>.
* Previously there was a factor of 5 placed on the covariance.
* I removed it, but it might be neccesary for some reason.
* @author Noam Dori
*/
private static double[][] computeCovariance( IterableInterval< RealType< ? > > iterable, double[] mean, int labelValue )
{
Cursor< RealType< ? > > cursor = iterable.cursor();
long counter = 0;
double[] position = new double[ 3 ];
double[][] covariance = new double[ 3 ][ 3 ];
cursor.reset();
while ( cursor.hasNext() )
{
int pixelValue = ( int ) cursor.next().getRealDouble();
if ( pixelValue == labelValue )
{
cursor.localize( position );
LinAlgHelpers.subtract( position, mean, position );
for ( int i = 0; i < 3; i++ )
for ( int j = 0; j < 3; j++ )
covariance[ i ][ j ] += position[ i ] * position[ j ];
counter++;
private void createSpotsFromSums(int timepointId, int numLabels, int[] count, long[][] sum, BigInteger[][][] mixedSum) {
// combine the sums into mean and covariance matrices, then add the corresponding spot
logger.debug("adding spots for the {} labels found", numLabels);
double[] mean = new double[3];
double[][] cov = new double[3][3];
for (int labelIdx = 0; labelIdx < numLabels; labelIdx++) {
for (int i = 0; i < 3; i++) {
mean[i] = sum[labelIdx][i] / (double) count[labelIdx] * voxelDimensions.dimension(i);
for (int j = i; j < 3; j++) { // the covariance matrix is symmetric!
cov[i][j] = mixedSum[labelIdx][i][j].multiply(valueOf(count[labelIdx]))
.subtract(valueOf(sum[labelIdx][i]).multiply(valueOf(sum[labelIdx][j])))
.doubleValue() / Math.pow(count[labelIdx], 2);
cov[i][j] *= Math.pow(sigma, 2) * voxelDimensions.dimension(i) * voxelDimensions.dimension(j);
if (i != j) {
cov[j][i] = cov[i][j];
}
}
}
modelGraph.addVertex().init(timepointId, mean, cov);
}

scale( covariance, 5. / counter ); // I don't know why the factor 5 is needed. But it works.
return covariance;
}

private static void scale( double[][] covariance, double factor )
{
for ( int i = 0; i < 3; i++ )
for ( int j = 0; j < 3; j++ )
covariance[ i ][ j ] *= factor;
/**
* Reads the image and prepares the coordinates of all labels to obtain the 0D (count), 1D (sums), and 2D (mixed)
* sums to prep the ground for the mean and covariance estimates.
* @param img the pointer to the image to read.
* @param count an empty array to store the 0D sums (counts). Dimensions: [labelIdx].
* @param sum an empty array to store the 1D sums, i.e S[X]. Dimensions: [labelIdx, coord]
* @param mixedSum an empty array to store the 2D sums, i.e S[XY]. Dimensions: [labelIdx, coord, coord]
* @param bg the pixel value of the background. Since unsigned is annoying in Fiji, this subtracts the bg value from the label.
* @author Noam Dori
*/
private static void readImageSumPositions(RandomAccessibleInterval<IntegerType<?>> img, int[] count,
long[][] sum, BigInteger[][][] mixedSum, int bg) {
// read the picture to sum everything up
int[] position = new int[3];
Cursor<IntegerType<?>> cursor = Views.iterable(img).cursor();
while (cursor.hasNext())
{
int labelIdx = cursor.next().getInteger() - bg - 1; // we ignore 0 as it is BG
if (labelIdx < 0) {
continue;
}
cursor.localize(position);
count[labelIdx]++;
for (int i = 0; i < 3; i++) {
sum[labelIdx][i] += position[i];
for (int j = i; j < 3; j++) { // the covariance matrix is symmetric!
mixedSum[labelIdx][i][j] =
mixedSum[labelIdx][i][j].add(valueOf(position[i]).multiply(valueOf(position[j])));
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@ public class ImportSpotsFromLabelsView implements Command
@Parameter(label = "Channel index of labels", min = "0")
private int labelChannelIndex = 0;

@Parameter(label = "Sigma", min = "0", description = "#deviations from center to form border")
private double sigma = 2.2;

@Override
public void run()
{
ImportSpotFromLabelsController controller = new ImportSpotFromLabelsController( appModel, context, labelChannelIndex );
ImportSpotFromLabelsController controller = new ImportSpotFromLabelsController( appModel, context, labelChannelIndex, sigma );
controller.createSpotsFromLabels();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package org.mastodon.mamut.segment;

import bdv.util.AbstractSource;
import bdv.util.RandomAccessibleIntervalSource;
import mpicbg.spim.data.sequence.FinalVoxelDimensions;
import mpicbg.spim.data.sequence.TimePoint;
import mpicbg.spim.data.sequence.VoxelDimensions;
import net.imglib2.RandomAccess;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.realtransform.AffineTransform3D;
import net.imglib2.type.numeric.integer.IntType;
import org.junit.Before;
import org.junit.Test;
import org.mastodon.mamut.model.Model;
import org.mastodon.mamut.model.Spot;
import org.mastodon.mamut.model.branch.ModelBranchGraph;
import org.scijava.Context;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;

import static org.junit.Assert.assertTrue;

public class ImportSpotFromLabelsControllerTest
{
private Model model;

private int timepoint;

@Before
public void setUp()
{
model = new Model();
ModelBranchGraph modelBranchGraph = model.getBranchGraph();
modelBranchGraph.graphRebuilt();
timepoint = 0;
}

@Test
public void testGetEllipsoidFromImage() {
AbstractSource<IntType> img = createImage();

Context context = new Context(true);
TimePoint timePoint = new TimePoint( timepoint );
List< TimePoint > timePoints = Collections.singletonList( timePoint );
VoxelDimensions voxelDimensions = new FinalVoxelDimensions("um", 0.16, 0.16, 1);
ImportSpotFromLabelsController controller = new ImportSpotFromLabelsController(model, timePoints, img, context, voxelDimensions, 2.2);

controller.createSpotsFromLabels();

Iterator<Spot> iter = model.getGraph().vertices().iterator();
assertTrue(iter.hasNext());

Spot s = iter.next();

s.getDoublePosition(0);
}

private static AbstractSource< IntType > createImage()
{
Img<IntType> img = new ArrayImgFactory<>(new IntType()).create(4, 4, 4);
RandomAccess<IntType> ra = img.randomAccess();
ra.setPositionAndGet(1, 1, 1).set(1);
ra.setPositionAndGet(2, 2, 2).set(1);
ra.setPositionAndGet(3, 3, 3).set(1);

return new RandomAccessibleIntervalSource<>( img, new IntType(), new AffineTransform3D(), "Segmentation" );
}
}

0 comments on commit 96e7695

Please sign in to comment.