Skip to content

Commit

Permalink
Make CudaCompute::GetDeformationField() on a par with CPU #92
Browse files Browse the repository at this point in the history
  • Loading branch information
onurulgen committed Oct 30, 2023
1 parent 2f65fc9 commit 97bce9e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 60 deletions.
2 changes: 1 addition & 1 deletion niftyreg_build_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
352
353
90 changes: 31 additions & 59 deletions reg-lib/cuda/_reg_localTransformation_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
#include "_reg_common_cuda_kernels.cu"

/* *************************************************************** */
__device__ void GetBasisBSplineValues(const double basis, float *values) {
const double ff = Square(basis);
const double fff = Cube(basis);
const double mf = 1.0 - basis;
values[0] = static_cast<float>(Cube(mf) / 6.0);
values[1] = static_cast<float>((3.0 * fff - 6.0 * ff + 4.0) / 6.0);
values[2] = static_cast<float>((-3.0 * fff + 3.0 * ff + 3.0 * basis + 1.0) / 6.0);
values[3] = static_cast<float>(fff / 6.0);
// Fills values[0..3] with the four cubic B-spline basis weights evaluated at
// the offset `basis`, all in single precision:
//   values[0] = (1 - t)^3 / 6
//   values[1] = (3t^3 - 6t^2 + 4) / 6
//   values[2] = (-3t^3 + 3t^2 + 3t + 1) / 6
//   values[3] = t^3 / 6
// `values` must point to at least four writable floats.
// NOTE(review): callers presumably pass a fractional offset in [0, 1) — confirm.
// Square() and Cube() are helpers defined outside this block.
__device__ void GetBasisBSplineValues(const float basis, float *values) {
    // Powers of the offset, reused by several of the weights below.
    const float basisSq = Square(basis);
    const float basisCu = basisSq * basis;
    // Complement of the offset; only the first weight needs it.
    const float complement = 1.f - basis;

    values[0] = Cube(complement) / 6.f;
    values[1] = (3.f * basisCu - 6.f * basisSq + 4.f) / 6.f;
    values[2] = (-3.f * basisCu + 3.f * basisSq + 3.f * basis + 1.f) / 6.f;
    values[3] = basisCu / 6.f;
}
/* *************************************************************** */
__device__ void GetFirstBSplineValues(const float basis, float *values, float *first) {
Expand Down Expand Up @@ -319,8 +319,6 @@ __global__ void reg_spline_getDeformationField3D(float4 *deformationField,
const bool bspline) {
const unsigned tid = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x;
if (tid >= activeVoxelNumber) return;
const int tid2 = tex1Dfetch<int>(maskTexture, tid);
const auto&& [x, y, z] = reg_indexToDims_cuda<true>(tid2, referenceImageDim);
int3 nodePre;
float3 basis;

Expand Down Expand Up @@ -349,6 +347,8 @@ __global__ void reg_spline_getDeformationField3D(float4 *deformationField,
nodePre = { Floor(xVoxel), Floor(yVoxel), Floor(zVoxel) };
basis = { xVoxel - float(nodePre.x--), yVoxel - float(nodePre.y--), zVoxel - float(nodePre.z--) };
} else { // starting deformation field is blank - !composition
const int tid2 = tex1Dfetch<int>(maskTexture, tid);
const auto&& [x, y, z] = reg_indexToDims_cuda<true>(tid2, referenceImageDim);
// The "nearest previous" node is determined [0,0,0]
const float xVoxel = float(x) / controlPointVoxelSpacing.x;
const float yVoxel = float(y) / controlPointVoxelSpacing.y;
Expand Down Expand Up @@ -377,39 +377,20 @@ __global__ void reg_spline_getDeformationField3D(float4 *deformationField,
else GetBasisSplineValues(basis.x, xBasis);

float4 displacement{};
for (int c = 0; c < 4; c++) {
float3 tempDisplacement{};
for (char c = 0; c < 4; c++) {
int indexYZ = ((nodePre.z + c) * controlPointImageDim.y + nodePre.y) * controlPointImageDim.x;
for (int b = 0; b < 4; b++) {
const float basisZ = zBasis[sharedMemIndex + c];
for (char b = 0; b < 4; b++, indexYZ += controlPointImageDim.x) {
int indexXYZ = indexYZ + nodePre.x;
const float4& nodeCoefficientA = tex1Dfetch<float4>(controlPointTexture, indexXYZ++);
const float4& nodeCoefficientB = tex1Dfetch<float4>(controlPointTexture, indexXYZ++);
const float4& nodeCoefficientC = tex1Dfetch<float4>(controlPointTexture, indexXYZ++);
const float4& nodeCoefficientD = tex1Dfetch<float4>(controlPointTexture, indexXYZ);

const float& basis = yBasis[sharedMemIndex + b];
tempDisplacement.x += basis * (nodeCoefficientA.x * xBasis[0] +
nodeCoefficientB.x * xBasis[1] +
nodeCoefficientC.x * xBasis[2] +
nodeCoefficientD.x * xBasis[3]);

tempDisplacement.y += basis * (nodeCoefficientA.y * xBasis[0] +
nodeCoefficientB.y * xBasis[1] +
nodeCoefficientC.y * xBasis[2] +
nodeCoefficientD.y * xBasis[3]);

tempDisplacement.z += basis * (nodeCoefficientA.z * xBasis[0] +
nodeCoefficientB.z * xBasis[1] +
nodeCoefficientC.z * xBasis[2] +
nodeCoefficientD.z * xBasis[3]);

indexYZ += controlPointImageDim.x;
const float basisY = yBasis[sharedMemIndex + b];
for (char a = 0; a < 4; a++, indexXYZ++) {
const float4& nodeCoeff = tex1Dfetch<float4>(controlPointTexture, indexXYZ);
const float xyzBasis = xBasis[a] * basisY * basisZ;
displacement.x += xyzBasis * nodeCoeff.x;
displacement.y += xyzBasis * nodeCoeff.y;
displacement.z += xyzBasis * nodeCoeff.z;
}
}

const float& basis = zBasis[sharedMemIndex + c];
displacement.x += basis * tempDisplacement.x;
displacement.y += basis * tempDisplacement.y;
displacement.z += basis * tempDisplacement.z;
}
deformationField[tid] = displacement;
}
Expand All @@ -426,8 +407,6 @@ __global__ void reg_spline_getDeformationField2D(float4 *deformationField,
const bool bspline) {
const unsigned tid = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x;
if (tid >= activeVoxelNumber) return;
const int tid2 = tex1Dfetch<int>(maskTexture, tid);
const auto&& [x, y, z] = reg_indexToDims_cuda<false>(tid2, referenceImageDim);
int2 nodePre;
float2 basis;

Expand All @@ -449,6 +428,8 @@ __global__ void reg_spline_getDeformationField2D(float4 *deformationField,
nodePre = { Floor(xVoxel), Floor(yVoxel) };
basis = { xVoxel - float(nodePre.x--), yVoxel - float(nodePre.y--) };
} else { // starting deformation field is blank - !composition
const int tid2 = tex1Dfetch<int>(maskTexture, tid);
const auto&& [x, y, z] = reg_indexToDims_cuda<false>(tid2, referenceImageDim);
// The "nearest previous" node is determined [0,0,0]
const float xVoxel = float(x) / controlPointVoxelSpacing.x;
const float yVoxel = float(y) / controlPointVoxelSpacing.y;
Expand All @@ -469,24 +450,15 @@ __global__ void reg_spline_getDeformationField2D(float4 *deformationField,
else GetBasisSplineValues(basis.x, xBasis);

float4 displacement{};
for (int b = 0; b < 4; b++) {
for (char b = 0; b < 4; b++) {
int index = (nodePre.y + b) * controlPointImageDim.x + nodePre.x;

const float4& nodeCoefficientA = tex1Dfetch<float4>(controlPointTexture, index++);
const float4& nodeCoefficientB = tex1Dfetch<float4>(controlPointTexture, index++);
const float4& nodeCoefficientC = tex1Dfetch<float4>(controlPointTexture, index++);
const float4& nodeCoefficientD = tex1Dfetch<float4>(controlPointTexture, index);

const float& basis = yBasis[sharedMemIndex + b];
displacement.x += basis * (nodeCoefficientA.x * xBasis[0] +
nodeCoefficientB.x * xBasis[1] +
nodeCoefficientC.x * xBasis[2] +
nodeCoefficientD.x * xBasis[3]);

displacement.y += basis * (nodeCoefficientA.y * xBasis[0] +
nodeCoefficientB.y * xBasis[1] +
nodeCoefficientC.y * xBasis[2] +
nodeCoefficientD.y * xBasis[3]);
const float basis = yBasis[sharedMemIndex + b];
for (char a = 0; a < 4; a++, index++) {
const float4& nodeCoeff = tex1Dfetch<float4>(controlPointTexture, index);
const float xyBasis = xBasis[a] * basis;
displacement.x += xyBasis * nodeCoeff.x;
displacement.y += xyBasis * nodeCoeff.y;
}
}
deformationField[tid] = displacement;
}
Expand Down

0 comments on commit 97bce9e

Please sign in to comment.