From 97bce9ecabefed32580fe3f475f1df24b4590325 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Onur=20=C3=9Clgen?= Date: Mon, 30 Oct 2023 17:56:23 +0000 Subject: [PATCH] Make CudaCompute::GetDeformationField() on a par with CPU #92 --- niftyreg_build_version.txt | 2 +- .../cuda/_reg_localTransformation_kernels.cu | 90 +++++++------------ 2 files changed, 32 insertions(+), 60 deletions(-) diff --git a/niftyreg_build_version.txt b/niftyreg_build_version.txt index ec9163d7..6fa50e78 100644 --- a/niftyreg_build_version.txt +++ b/niftyreg_build_version.txt @@ -1 +1 @@ -352 +353 diff --git a/reg-lib/cuda/_reg_localTransformation_kernels.cu b/reg-lib/cuda/_reg_localTransformation_kernels.cu index 69e44967..ba459d22 100755 --- a/reg-lib/cuda/_reg_localTransformation_kernels.cu +++ b/reg-lib/cuda/_reg_localTransformation_kernels.cu @@ -13,14 +13,14 @@ #include "_reg_common_cuda_kernels.cu" /* *************************************************************** */ -__device__ void GetBasisBSplineValues(const double basis, float *values) { - const double ff = Square(basis); - const double fff = Cube(basis); - const double mf = 1.0 - basis; - values[0] = static_cast(Cube(mf) / 6.0); - values[1] = static_cast((3.0 * fff - 6.0 * ff + 4.0) / 6.0); - values[2] = static_cast((-3.0 * fff + 3.0 * ff + 3.0 * basis + 1.0) / 6.0); - values[3] = static_cast(fff / 6.0); +__device__ void GetBasisBSplineValues(const float basis, float *values) { + const float ff = Square(basis); + const float fff = ff * basis; + const float mf = 1.f - basis; + values[0] = Cube(mf) / 6.f; + values[1] = (3.f * fff - 6.f * ff + 4.f) / 6.f; + values[2] = (-3.f * fff + 3.f * ff + 3.f * basis + 1.f) / 6.f; + values[3] = fff / 6.f; } /* *************************************************************** */ __device__ void GetFirstBSplineValues(const float basis, float *values, float *first) { @@ -319,8 +319,6 @@ __global__ void reg_spline_getDeformationField3D(float4 *deformationField, const bool bspline) { const unsigned tid = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x; if (tid >= activeVoxelNumber) return; - const int tid2 = tex1Dfetch(maskTexture, tid); - const auto&& [x, y, z] = reg_indexToDims_cuda(tid2, referenceImageDim); int3 nodePre; float3 basis; @@ -349,6 +347,8 @@ __global__ void reg_spline_getDeformationField3D(float4 *deformationField, nodePre = { Floor(xVoxel), Floor(yVoxel), Floor(zVoxel) }; basis = { xVoxel - float(nodePre.x--), yVoxel - float(nodePre.y--), zVoxel - float(nodePre.z--) }; } else { // starting deformation field is blank - !composition + const int tid2 = tex1Dfetch(maskTexture, tid); + const auto&& [x, y, z] = reg_indexToDims_cuda(tid2, referenceImageDim); // The "nearest previous" node is determined [0,0,0] const float xVoxel = float(x) / controlPointVoxelSpacing.x; const float yVoxel = float(y) / controlPointVoxelSpacing.y; @@ -377,39 +377,20 @@ __global__ void reg_spline_getDeformationField3D(float4 *deformationField, else GetBasisSplineValues(basis.x, xBasis); float4 displacement{}; - for (int c = 0; c < 4; c++) { - float3 tempDisplacement{}; + for (char c = 0; c < 4; c++) { int indexYZ = ((nodePre.z + c) * controlPointImageDim.y + nodePre.y) * controlPointImageDim.x; - for (int b = 0; b < 4; b++) { + const float basisZ = zBasis[sharedMemIndex + c]; + for (char b = 0; b < 4; b++, indexYZ += controlPointImageDim.x) { int indexXYZ = indexYZ + nodePre.x; - const float4& nodeCoefficientA = tex1Dfetch(controlPointTexture, indexXYZ++); - const float4& nodeCoefficientB = tex1Dfetch(controlPointTexture, indexXYZ++); - const float4& nodeCoefficientC = tex1Dfetch(controlPointTexture, indexXYZ++); - const float4& nodeCoefficientD = tex1Dfetch(controlPointTexture, indexXYZ); - - const float& basis = yBasis[sharedMemIndex + b]; - tempDisplacement.x += basis * (nodeCoefficientA.x * xBasis[0] + - nodeCoefficientB.x * xBasis[1] + - nodeCoefficientC.x * xBasis[2] + - nodeCoefficientD.x * xBasis[3]); - - tempDisplacement.y += basis * (nodeCoefficientA.y * xBasis[0] + - nodeCoefficientB.y * xBasis[1] + - nodeCoefficientC.y * xBasis[2] + - nodeCoefficientD.y * xBasis[3]); - - tempDisplacement.z += basis * (nodeCoefficientA.z * xBasis[0] + - nodeCoefficientB.z * xBasis[1] + - nodeCoefficientC.z * xBasis[2] + - nodeCoefficientD.z * xBasis[3]); - - indexYZ += controlPointImageDim.x; + const float basisY = yBasis[sharedMemIndex + b]; + for (char a = 0; a < 4; a++, indexXYZ++) { + const float4& nodeCoeff = tex1Dfetch(controlPointTexture, indexXYZ); + const float xyzBasis = xBasis[a] * basisY * basisZ; + displacement.x += xyzBasis * nodeCoeff.x; + displacement.y += xyzBasis * nodeCoeff.y; + displacement.z += xyzBasis * nodeCoeff.z; + } } - - const float& basis = zBasis[sharedMemIndex + c]; - displacement.x += basis * tempDisplacement.x; - displacement.y += basis * tempDisplacement.y; - displacement.z += basis * tempDisplacement.z; } deformationField[tid] = displacement; } @@ -426,8 +407,6 @@ __global__ void reg_spline_getDeformationField2D(float4 *deformationField, const bool bspline) { const unsigned tid = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x; if (tid >= activeVoxelNumber) return; - const int tid2 = tex1Dfetch(maskTexture, tid); - const auto&& [x, y, z] = reg_indexToDims_cuda(tid2, referenceImageDim); int2 nodePre; float2 basis; @@ -449,6 +428,8 @@ __global__ void reg_spline_getDeformationField2D(float4 *deformationField, nodePre = { Floor(xVoxel), Floor(yVoxel) }; basis = { xVoxel - float(nodePre.x--), yVoxel - float(nodePre.y--) }; } else { // starting deformation field is blank - !composition + const int tid2 = tex1Dfetch(maskTexture, tid); + const auto&& [x, y, z] = reg_indexToDims_cuda(tid2, referenceImageDim); // The "nearest previous" node is determined [0,0,0] const float xVoxel = float(x) / controlPointVoxelSpacing.x; const float yVoxel = float(y) / controlPointVoxelSpacing.y; @@ -469,24 +450,15 @@ __global__ void reg_spline_getDeformationField2D(float4 *deformationField, else GetBasisSplineValues(basis.x, xBasis); float4 displacement{}; - for (int b = 0; b < 4; b++) { + for (char b = 0; b < 4; b++) { int index = (nodePre.y + b) * controlPointImageDim.x + nodePre.x; - - const float4& nodeCoefficientA = tex1Dfetch(controlPointTexture, index++); - const float4& nodeCoefficientB = tex1Dfetch(controlPointTexture, index++); - const float4& nodeCoefficientC = tex1Dfetch(controlPointTexture, index++); - const float4& nodeCoefficientD = tex1Dfetch(controlPointTexture, index); - - const float& basis = yBasis[sharedMemIndex + b]; - displacement.x += basis * (nodeCoefficientA.x * xBasis[0] + - nodeCoefficientB.x * xBasis[1] + - nodeCoefficientC.x * xBasis[2] + - nodeCoefficientD.x * xBasis[3]); - - displacement.y += basis * (nodeCoefficientA.y * xBasis[0] + - nodeCoefficientB.y * xBasis[1] + - nodeCoefficientC.y * xBasis[2] + - nodeCoefficientD.y * xBasis[3]); + const float basis = yBasis[sharedMemIndex + b]; + for (char a = 0; a < 4; a++, index++) { + const float4& nodeCoeff = tex1Dfetch(controlPointTexture, index); + const float xyBasis = xBasis[a] * basis; + displacement.x += xyBasis * nodeCoeff.x; + displacement.y += xyBasis * nodeCoeff.y; + } } deformationField[tid] = displacement; }