Skip to content

Commit

Permalink
Add composition support for CudaCompute::GetDeformationField() #92
Browse files Browse the repository at this point in the history
  • Loading branch information
onurulgen committed Oct 13, 2023
1 parent a8f1232 commit d925b8c
Show file tree
Hide file tree
Showing 11 changed files with 302 additions and 251 deletions.
2 changes: 1 addition & 1 deletion niftyreg_build_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
342
343
69 changes: 33 additions & 36 deletions reg-lib/cpu/_reg_localTrans.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ void reg_cubic_spline_getDeformationField2D(nifti_image *splineControlPoint,
} val;
__m128 tempCurrent, tempX, tempY;
#ifdef _WIN32
__declspec(align(16)) DataType temp[4];
__declspec(align(16)) DataType xBasis[4];
__declspec(align(16)) DataType yBasis[4];
union {
__m128 m[16];
Expand All @@ -578,7 +578,7 @@ void reg_cubic_spline_getDeformationField2D(nifti_image *splineControlPoint,
__declspec(align(16)) DataType f[16];
} xyBasis;
#else // _WIN32
DataType temp[4] __attribute__((aligned(16)));
DataType xBasis[4] __attribute__((aligned(16)));
DataType yBasis[4] __attribute__((aligned(16)));
union {
__m128 m[16];
Expand All @@ -594,7 +594,7 @@ void reg_cubic_spline_getDeformationField2D(nifti_image *splineControlPoint,
} xyBasis;
#endif // _WIN32
#else // _USE_SSE
DataType temp[4];
DataType xBasis[4];
DataType yBasis[4];
DataType xyBasis[16];
DataType xControlPointCoordinates[16];
Expand Down Expand Up @@ -626,7 +626,6 @@ void reg_cubic_spline_getDeformationField2D(nifti_image *splineControlPoint,
index = y * deformationField->nx;
oldXpre = oldYpre = 99999999;
for (x = 0; x < deformationField->nx; x++) {

// The previous position at the current pixel position is read
xReal = static_cast<DataType>(fieldPtrX[index]);
yReal = static_cast<DataType>(fieldPtrY[index]);
Expand All @@ -643,8 +642,8 @@ void reg_cubic_spline_getDeformationField2D(nifti_image *splineControlPoint,
xPre = Floor(xVoxel);
basis = xVoxel - static_cast<DataType>(xPre--);
if (basis < 0) basis = 0; //rounding error
if (bspline) get_BSplineBasisValues<DataType>(basis, temp);
else get_SplineBasisValues<DataType>(basis, temp);
if (bspline) get_BSplineBasisValues<DataType>(basis, xBasis);
else get_SplineBasisValues<DataType>(basis, xBasis);

yPre = Floor(yVoxel);
basis = yVoxel - static_cast<DataType>(yPre--);
Expand Down Expand Up @@ -688,7 +687,7 @@ void reg_cubic_spline_getDeformationField2D(nifti_image *splineControlPoint,
coord = 0;
for (b = 0; b < 4; b++) {
for (a = 0; a < 4; a++) {
xyBasis.f[coord++] = temp[a] * yBasis[b];
xyBasis.f[coord++] = xBasis[a] * yBasis[b];
}
}

Expand All @@ -707,7 +706,7 @@ void reg_cubic_spline_getDeformationField2D(nifti_image *splineControlPoint,
#else
for (b = 0; b < 4; b++) {
for (a = 0; a < 4; a++) {
DataType tempValue = temp[a] * yBasis[b];
DataType tempValue = xBasis[a] * yBasis[b];
xReal += xControlPointCoordinates[b * 4 + a] * tempValue;
yReal += yControlPointCoordinates[b * 4 + a] * tempValue;
}
Expand All @@ -728,14 +727,14 @@ void reg_cubic_spline_getDeformationField2D(nifti_image *splineControlPoint,
shared(deformationField, gridVoxelSpacing, splineControlPoint, controlPointPtrX, \
controlPointPtrY, mask, fieldPtrX, fieldPtrY, bspline) \
private(x, a, xPre, yPre, oldXpre, oldYpre, index, xReal, yReal, basis, \
val, temp, yBasis, tempCurrent, xyBasis, tempX, tempY, \
val, xBasis, yBasis, tempCurrent, xyBasis, tempX, tempY, \
xControlPointCoordinates, yControlPointCoordinates)
#else // _USE_SSE
#pragma omp parallel for default(none) \
shared(deformationField, gridVoxelSpacing, splineControlPoint, controlPointPtrX, \
controlPointPtrY, mask, fieldPtrX, fieldPtrY, bspline) \
private(x, a, xPre, yPre, oldXpre, oldYpre, index, xReal, yReal, basis, coord, \
temp, yBasis, xyBasis, xControlPointCoordinates, yControlPointCoordinates)
xBasis, yBasis, xyBasis, xControlPointCoordinates, yControlPointCoordinates)
#endif // _USE_SSE
#endif // _OPENMP
for (y = 0; y < deformationField->ny; y++) {
Expand All @@ -744,21 +743,21 @@ void reg_cubic_spline_getDeformationField2D(nifti_image *splineControlPoint,

yPre = static_cast<int>(static_cast<DataType>(y) / gridVoxelSpacing[1]);
basis = static_cast<DataType>(y) / gridVoxelSpacing[1] - static_cast<DataType>(yPre);
if (basis < 0) basis = 0; //rounding error
if (basis < 0) basis = 0; // rounding error
if (bspline) get_BSplineBasisValues<DataType>(basis, yBasis);
else get_SplineBasisValues<DataType>(basis, yBasis);

for (x = 0; x < deformationField->nx; x++) {
xPre = static_cast<int>(static_cast<DataType>(x) / gridVoxelSpacing[0]);
basis = static_cast<DataType>(x) / gridVoxelSpacing[0] - static_cast<DataType>(xPre);
if (basis < 0) basis = 0; //rounding error
if (bspline) get_BSplineBasisValues<DataType>(basis, temp);
else get_SplineBasisValues<DataType>(basis, temp);
if (basis < 0) basis = 0; // rounding error
if (bspline) get_BSplineBasisValues<DataType>(basis, xBasis);
else get_SplineBasisValues<DataType>(basis, xBasis);
#if _USE_SSE
val.f[0] = static_cast<float>(temp[0]);
val.f[1] = static_cast<float>(temp[1]);
val.f[2] = static_cast<float>(temp[2]);
val.f[3] = static_cast<float>(temp[3]);
val.f[0] = static_cast<float>(xBasis[0]);
val.f[1] = static_cast<float>(xBasis[1]);
val.f[2] = static_cast<float>(xBasis[2]);
val.f[3] = static_cast<float>(xBasis[3]);
tempCurrent = val.m;
for (a = 0; a < 4; a++) {
val.m = _mm_set_ps1(static_cast<float>(yBasis[a]));
Expand All @@ -767,10 +766,10 @@ void reg_cubic_spline_getDeformationField2D(nifti_image *splineControlPoint,
#else
coord = 0;
for (a = 0; a < 4; a++) {
xyBasis[coord++] = temp[0] * yBasis[a];
xyBasis[coord++] = temp[1] * yBasis[a];
xyBasis[coord++] = temp[2] * yBasis[a];
xyBasis[coord++] = temp[3] * yBasis[a];
xyBasis[coord++] = xBasis[0] * yBasis[a];
xyBasis[coord++] = xBasis[1] * yBasis[a];
xyBasis[coord++] = xBasis[2] * yBasis[a];
xyBasis[coord++] = xBasis[3] * yBasis[a];
}
#endif
if (oldXpre != xPre || oldYpre != yPre) {
Expand Down Expand Up @@ -837,7 +836,7 @@ void reg_cubic_spline_getDeformationField3D(nifti_image *splineControlPoint,
int *mask,
bool composition,
bool bspline,
bool force_no_lut = false) {
bool forceNoLut = false) {
#if _USE_SSE
union {
__m128 m;
Expand Down Expand Up @@ -1111,7 +1110,7 @@ void reg_cubic_spline_getDeformationField3D(nifti_image *splineControlPoint,
#endif // _USE_SSE

// Assess if lookup table can be used
if (gridVoxelSpacing[0] == 5. && gridVoxelSpacing[0] == 5. && gridVoxelSpacing[0] == 5. && force_no_lut == false) {
if (gridVoxelSpacing[0] == 5. && gridVoxelSpacing[0] == 5. && gridVoxelSpacing[0] == 5. && forceNoLut == false) {
// Assign a single array that will contain all coefficients
DataType *coefficients = (DataType*)malloc(125 * 64 * sizeof(DataType));
// Compute and store all required coefficients
Expand Down Expand Up @@ -1462,7 +1461,7 @@ void reg_spline_getDeformationField(nifti_image *splineControlPoint,
int *mask,
bool composition,
bool bspline,
bool force_no_lut) {
bool forceNoLut) {
if (splineControlPoint->datatype != deformationField->datatype)
NR_FATAL_ERROR("The spline control point image and the deformation field image are expected to be of the same type");

Expand All @@ -1471,11 +1470,11 @@ void reg_spline_getDeformationField(nifti_image *splineControlPoint,
NR_FATAL_ERROR("SSE computation has only been implemented for single precision");
#endif

bool MrPropre = false;
if (mask == nullptr) {
unique_ptr<int[]> currentMask;
if (!mask) {
// Active voxels are all those with a mask value greater than -1, so a zero-filled mask will do
MrPropre = true;
mask = (int*)calloc(NiftiImage::calcVoxelNumber(deformationField, 3), sizeof(int));
currentMask.reset(new int[NiftiImage::calcVoxelNumber(deformationField, 3)]());
mask = currentMask.get();
}

// Check if an affine initialisation is required
Expand Down Expand Up @@ -1519,10 +1518,10 @@ void reg_spline_getDeformationField(nifti_image *splineControlPoint,
} else {
switch (deformationField->datatype) {
case NIFTI_TYPE_FLOAT32:
reg_cubic_spline_getDeformationField3D<float>(splineControlPoint, deformationField, mask, composition, bspline, force_no_lut);
reg_cubic_spline_getDeformationField3D<float>(splineControlPoint, deformationField, mask, composition, bspline, forceNoLut);
break;
case NIFTI_TYPE_FLOAT64:
reg_cubic_spline_getDeformationField3D<double>(splineControlPoint, deformationField, mask, composition, bspline, force_no_lut);
reg_cubic_spline_getDeformationField3D<double>(splineControlPoint, deformationField, mask, composition, bspline, forceNoLut);
break;
default:
NR_FATAL_ERROR("Only single or double precision is implemented for deformation field");
Expand All @@ -1534,12 +1533,10 @@ void reg_spline_getDeformationField(nifti_image *splineControlPoint,
if (splineControlPoint->ext_list[1].edata != nullptr) {
reg_affine_getDeformationField(reinterpret_cast<mat44*>(splineControlPoint->ext_list[1].edata),
deformationField,
true, //composition
true, // composition
mask);
}
}
if (MrPropre)
free(mask);
}
/* *************************************************************** */
template<class DataType>
Expand Down Expand Up @@ -3497,7 +3494,7 @@ void reg_spline_getFlowFieldFromVelocityGrid(nifti_image *velocityFieldGrid,
flowField->intent_p1 = DISP_VEL_FIELD;
reg_getDeformationFromDisplacement(flowField);

// fake the number of extension here to avoid the second half of the affine
// Fake the number of extension here to avoid the second half of the affine
int oldNumExt = velocityFieldGrid->num_ext;
if (oldNumExt > 1)
velocityFieldGrid->num_ext = 1;
Expand All @@ -3508,7 +3505,7 @@ void reg_spline_getFlowFieldFromVelocityGrid(nifti_image *velocityFieldGrid,
reg_spline_getDeformationField(velocityFieldGrid,
flowField,
nullptr, // mask
true, //composition
true, // composition
true); // bspline

velocityFieldGrid->num_ext = oldNumExt;
Expand Down
2 changes: 1 addition & 1 deletion reg-lib/cpu/_reg_localTrans.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void reg_spline_getDeformationField(nifti_image *controlPointGridImage,
int *mask = nullptr,
bool composition = false,
bool bspline = true,
bool force_no_lut = false);
bool forceNoLut = false);
/* *************************************************************** */
/** @brief Upsample an image from voxel space to node space using
* millimetre correspondences.
Expand Down
92 changes: 45 additions & 47 deletions reg-lib/cpu/_reg_splineBasis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,36 +460,34 @@ template void set_second_order_bspline_basis_values<double>(double*, double*, do
template <class DataType>
void get_SlidedValues(DataType& defX,
DataType& defY,
const int X,
const int Y,
const int x,
const int y,
const DataType *defPtrX,
const DataType *defPtrY,
const mat44 *df_voxel2Real,
const mat44 *dfVoxel2Real,
const int *dim,
const bool displacement) {
int newX = X;
int newY = Y;
if (X < 0) {
int newX = x;
if (x < 0)
newX = 0;
} else if (X >= dim[1]) {
else if (x >= dim[1])
newX = dim[1] - 1;
}
if (Y < 0) {

int newY = y;
if (y < 0)
newY = 0;
} else if (Y >= dim[2]) {
else if (y >= dim[2])
newY = dim[2] - 1;
}

DataType shiftValueX = 0;
DataType shiftValueY = 0;
if (!displacement) {
int shiftIndexX = X - newX;
int shiftIndexY = Y - newY;
shiftValueX = shiftIndexX * df_voxel2Real->m[0][0] +
shiftIndexY * df_voxel2Real->m[0][1];
shiftValueY = shiftIndexX * df_voxel2Real->m[1][0] +
shiftIndexY * df_voxel2Real->m[1][1];
const int shiftIndexX = x - newX;
const int shiftIndexY = y - newY;
shiftValueX = shiftIndexX * dfVoxel2Real->m[0][0] + shiftIndexY * dfVoxel2Real->m[0][1];
shiftValueY = shiftIndexX * dfVoxel2Real->m[1][0] + shiftIndexY * dfVoxel2Real->m[1][1];
}
size_t index = newY * dim[1] + newX;
const int index = newY * dim[1] + newX;
defX = defPtrX[index] + shiftValueX;
defY = defPtrY[index] + shiftValueY;
}
Expand All @@ -500,54 +498,54 @@ template <class DataType>
void get_SlidedValues(DataType& defX,
DataType& defY,
DataType& defZ,
const int X,
const int Y,
const int Z,
const int x,
const int y,
const int z,
const DataType *defPtrX,
const DataType *defPtrY,
const DataType *defPtrZ,
const mat44 *df_voxel2Real,
const mat44 *dfVoxel2Real,
const int *dim,
const bool displacement) {
int newX = X;
int newY = Y;
int newZ = Z;
if (X < 0) {
int newX = x;
if (x < 0)
newX = 0;
} else if (X >= dim[1]) {
else if (x >= dim[1])
newX = dim[1] - 1;
}
if (Y < 0) {

int newY = y;
if (y < 0)
newY = 0;
} else if (Y >= dim[2]) {
else if (y >= dim[2])
newY = dim[2] - 1;
}
if (Z < 0) {

int newZ = z;
if (z < 0)
newZ = 0;
} else if (Z >= dim[3]) {
else if (z >= dim[3])
newZ = dim[3] - 1;
}

DataType shiftValueX = 0;
DataType shiftValueY = 0;
DataType shiftValueZ = 0;
if (!displacement) {
int shiftIndexX = X - newX;
int shiftIndexY = Y - newY;
int shiftIndexZ = Z - newZ;
const int shiftIndexX = x - newX;
const int shiftIndexY = y - newY;
const int shiftIndexZ = z - newZ;
shiftValueX =
shiftIndexX * df_voxel2Real->m[0][0] +
shiftIndexY * df_voxel2Real->m[0][1] +
shiftIndexZ * df_voxel2Real->m[0][2];
shiftIndexX * dfVoxel2Real->m[0][0] +
shiftIndexY * dfVoxel2Real->m[0][1] +
shiftIndexZ * dfVoxel2Real->m[0][2];
shiftValueY =
shiftIndexX * df_voxel2Real->m[1][0] +
shiftIndexY * df_voxel2Real->m[1][1] +
shiftIndexZ * df_voxel2Real->m[1][2];
shiftIndexX * dfVoxel2Real->m[1][0] +
shiftIndexY * dfVoxel2Real->m[1][1] +
shiftIndexZ * dfVoxel2Real->m[1][2];
shiftValueZ =
shiftIndexX * df_voxel2Real->m[2][0] +
shiftIndexY * df_voxel2Real->m[2][1] +
shiftIndexZ * df_voxel2Real->m[2][2];
shiftIndexX * dfVoxel2Real->m[2][0] +
shiftIndexY * dfVoxel2Real->m[2][1] +
shiftIndexZ * dfVoxel2Real->m[2][2];
}
size_t index = (newZ * dim[2] + newY) * dim[1] + newX;
const int index = (newZ * dim[2] + newY) * dim[1] + newX;
defX = defPtrX[index] + shiftValueX;
defY = defPtrY[index] + shiftValueY;
defZ = defPtrZ[index] + shiftValueZ;
Expand Down
14 changes: 7 additions & 7 deletions reg-lib/cpu/_reg_splineBasis.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,24 +84,24 @@ void get_SplineBasisValues(DataType basis,
template <class DataType>
void get_SlidedValues(DataType &defX,
DataType &defY,
const int X,
const int Y,
const int x,
const int y,
const DataType *defPtrX,
const DataType *defPtrY,
const mat44 *df_voxel2Real,
const mat44 *dfVoxel2Real,
const int *dim,
const bool displacement);
template <class DataType>
void get_SlidedValues(DataType &defX,
DataType &defY,
DataType &defZ,
const int X,
const int Y,
const int Z,
const int x,
const int y,
const int z,
const DataType *defPtrX,
const DataType *defPtrY,
const DataType *defPtrZ,
const mat44 *df_voxel2Real,
const mat44 *dfVoxel2Real,
const int *dim,
const bool displacement);

Expand Down
Loading

0 comments on commit d925b8c

Please sign in to comment.