Skip to content

Commit

Permalink
Add multi-timepoint support for MeasureTest #92
Browse files Browse the repository at this point in the history
  • Loading branch information
onurulgen committed Nov 15, 2023
1 parent bc4c672 commit 86db434
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 20 deletions.
2 changes: 1 addition & 1 deletion niftyreg_build_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
362
363
41 changes: 22 additions & 19 deletions reg-test/reg_test_regr_measure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,15 @@ class MeasureTest {

// Create 2D reference, floating, control point grid and local weight similarity images
constexpr NiftiImage::dim_t size = 16;
vector<NiftiImage::dim_t> dim{ size, size };
constexpr NiftiImage::dim_t timePoints = 1;
vector<NiftiImage::dim_t> dim{ size, size, 1, timePoints };
NiftiImage reference2d(dim, NIFTI_TYPE_FLOAT32);
NiftiImage floating2d(dim, NIFTI_TYPE_FLOAT32);
NiftiImage controlPointGrid2d(CreateControlPointGrid(reference2d));
NiftiImage localWeightSim2d(dim, NIFTI_TYPE_FLOAT32);

// Create 3D reference, floating, control point grid and local weight similarity images
dim.push_back(size);
dim[2] = size;
NiftiImage reference3d(dim, NIFTI_TYPE_FLOAT32);
NiftiImage floating3d(dim, NIFTI_TYPE_FLOAT32);
NiftiImage controlPointGrid3d(CreateControlPointGrid(reference3d));
Expand Down Expand Up @@ -63,7 +64,7 @@ class MeasureTest {

// Create the data container for the regression test
const std::string measureNames[]{ "NMI"s, "SSD"s, "DTI"s, "LNCC"s, "KLD"s, "MIND"s, "MINDSSC"s };
const MeasureType testMeasures[]{ MeasureType::Nmi, MeasureType::Ssd };
constexpr MeasureType testMeasures[]{ MeasureType::Nmi, MeasureType::Ssd };
vector<TestData> testData;
for (auto&& measure : testMeasures) {
for (int sym = 0; sym < 2; ++sym) {
Expand Down Expand Up @@ -137,9 +138,9 @@ class MeasureTest {
unique_ptr<reg_measure> measureCuda{ measureCreatorCuda->Create(measureType) };

// Initialise the measures
for (int i = 0; i < referenceCpu->nt; ++i) {
measureCpu->SetTimePointWeight(i, 1.0);
measureCuda->SetTimePointWeight(i, 1.0);
for (int t = 0; t < referenceCpu->nt; t++) {
measureCpu->SetTimePointWeight(t, 1.0);
measureCuda->SetTimePointWeight(t, 1.0);
}
measureCreatorCpu->Initialise(*measureCpu, *contentCpu, contentCpuBw.get());
measureCreatorCuda->Initialise(*measureCuda, *contentCuda, contentCudaBw.get());
Expand All @@ -162,24 +163,26 @@ class MeasureTest {
}
const double simMeasureCuda = measureCuda->GetSimilarityMeasureValue();

// Compute the similarity measure gradient for CPU
constexpr int timepoint = 0;
// Compute the similarity measure gradients
contentCpu->ZeroVoxelBasedMeasureGradient();
computeCpu->GetImageGradient(1, std::numeric_limits<float>::quiet_NaN(), timepoint);
if (isSymmetric) {
contentCpuBw->ZeroVoxelBasedMeasureGradient();
computeCpuBw->GetImageGradient(1, std::numeric_limits<float>::quiet_NaN(), timepoint);
}
measureCpu->GetVoxelBasedSimilarityMeasureGradient(timepoint);

// Compute the similarity measure gradient for CUDA
contentCuda->ZeroVoxelBasedMeasureGradient();
computeCuda->GetImageGradient(1, std::numeric_limits<float>::quiet_NaN(), timepoint);
if (isSymmetric) {
contentCpuBw->ZeroVoxelBasedMeasureGradient();
contentCudaBw->ZeroVoxelBasedMeasureGradient();
computeCudaBw->GetImageGradient(1, std::numeric_limits<float>::quiet_NaN(), timepoint);
}
measureCuda->GetVoxelBasedSimilarityMeasureGradient(timepoint);
for (int t = 0; t < referenceCpu->nt; t++) {
// Compute the similarity measure gradient for CPU
computeCpu->GetImageGradient(1, std::numeric_limits<float>::quiet_NaN(), t);
if (isSymmetric)
computeCpuBw->GetImageGradient(1, std::numeric_limits<float>::quiet_NaN(), t);
measureCpu->GetVoxelBasedSimilarityMeasureGradient(t);

// Compute the similarity measure gradient for CUDA
computeCuda->GetImageGradient(1, std::numeric_limits<float>::quiet_NaN(), t);
if (isSymmetric)
computeCudaBw->GetImageGradient(1, std::numeric_limits<float>::quiet_NaN(), t);
measureCuda->GetVoxelBasedSimilarityMeasureGradient(t);
}

// Get the voxel-based similarity measure gradients
NiftiImage voxelBasedGradCpu(contentCpu->GetVoxelBasedMeasureGradient(), NiftiImage::Copy::Image);
Expand Down

0 comments on commit 86db434

Please sign in to comment.