Commit

Cameras (#198)
* opencv_lens_undistortion

* fix k4 bug for undistortion, support fisheye

* support k3 k4 k5 k6

* fix _opencv_len_distortion; format

* naming: len->lens
liruilong940607 authored Apr 9, 2023

1 parent 4f7965c commit ebeb5dd
Showing 7 changed files with 850 additions and 31 deletions.
211 changes: 211 additions & 0 deletions nerfacc/cameras.py
@@ -0,0 +1,211 @@
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import Tensor

from . import cuda as _C


def opencv_lens_undistortion(
uv: Tensor, params: Tensor, eps: float = 1e-6, iters: int = 10
) -> Tensor:
"""Undistort the opencv distortion.
Note:
This function is not differentiable to any inputs.
Args:
uv: (..., 2) UV coordinates.
params: (..., N) or (N) OpenCV distortion parameters. We support
N = 0, 1, 2, 4, 8. If N = 0, we return the input uv directly.
If N = 1, we assume the input is {k1}. If N = 2, we assume the
input is {k1, k2}. If N = 4, we assume the input is {k1, k2, p1, p2}.
If N = 8, we assume the input is {k1, k2, p1, p2, k3, k4, k5, k6}.
Returns:
(..., 2) undistorted UV coordinates.
"""
assert uv.shape[-1] == 2
assert params.shape[-1] in [0, 1, 2, 4, 8]

if params.shape[-1] == 0:
return uv
elif params.shape[-1] < 8:
params = F.pad(params, (0, 8 - params.shape[-1]), "constant", 0)
assert params.shape[-1] == 8

batch_shape = uv.shape[:-1]
params = torch.broadcast_to(params, batch_shape + (params.shape[-1],))

return _C.opencv_lens_undistortion(
uv.contiguous(), params.contiguous(), eps, iters
)


def opencv_lens_undistortion_fisheye(
uv: Tensor, params: Tensor, eps: float = 1e-6, iters: int = 10
) -> Tensor:
"""Undistort the opencv distortion of {k1, k2, k3, k4}.
Note:
This function is not differentiable to any inputs.
Args:
uv: (..., 2) UV coordinates.
params: (..., 4) or (4) OpenCV distortion parameters.
Returns:
(..., 2) undistorted UV coordinates.
"""
assert uv.shape[-1] == 2
assert params.shape[-1] == 4
batch_shape = uv.shape[:-1]
params = torch.broadcast_to(params, batch_shape + (params.shape[-1],))

return _C.opencv_lens_undistortion_fisheye(
uv.contiguous(), params.contiguous(), eps, iters
)


def _opencv_lens_distortion(uv: Tensor, params: Tensor) -> Tensor:
"""The opencv camera distortion of {k1, k2, p1, p2, k3, k4, k5, k6}.
See https://docs.opencv.org/3.4/d9/d0c/group__calib3d.html for more details.
"""
k1, k2, p1, p2, k3, k4, k5, k6 = torch.unbind(params, dim=-1)
s1, s2, s3, s4 = 0, 0, 0, 0
u, v = torch.unbind(uv, dim=-1)
r2 = u * u + v * v
r4 = r2**2
r6 = r4 * r2
radial = (1 + k1 * r2 + k2 * r4 + k3 * r6) / (
1 + k4 * r2 + k5 * r4 + k6 * r6
)
fx = 2 * p1 * u * v + p2 * (r2 + 2 * u * u) + s1 * r2 + s2 * r4
fy = 2 * p2 * u * v + p1 * (r2 + 2 * v * v) + s3 * r2 + s4 * r4
return torch.stack([u * radial + fx, v * radial + fy], dim=-1)


def _opencv_lens_distortion_fisheye(
uv: Tensor, params: Tensor, eps: float = 1e-10
) -> Tensor:
"""The opencv camera distortion of {k1, k2, k3, p1, p2}.
See https://docs.opencv.org/4.x/db/d58/group__calib3d__fisheye.html for more details.
Args:
uv: (..., 2) UV coordinates.
params: (..., 4) or (4) OpenCV distortion parameters.
Returns:
(..., 2) distorted UV coordinates.
"""
assert params.shape[-1] == 4, f"Invalid params shape: {params.shape}"
k1, k2, k3, k4 = torch.unbind(params, dim=-1)
u, v = torch.unbind(uv, dim=-1)
r = torch.sqrt(u * u + v * v)
theta = torch.atan(r)
theta_d = theta * (
1
+ k1 * theta**2
+ k2 * theta**4
+ k3 * theta**6
+ k4 * theta**8
)
scale = theta_d / torch.clamp(r, min=eps)
return uv * scale[..., None]


@torch.jit.script
def _compute_residual_and_jacobian(
x: Tensor, y: Tensor, xd: Tensor, yd: Tensor, params: Tensor
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
assert params.shape[-1] == 8

k1, k2, p1, p2, k3, k4, k5, k6 = torch.unbind(params, dim=-1)

# let r(x, y) = x^2 + y^2;
# alpha(x, y) = 1 + k1 * r(x, y) + k2 * r(x, y) ^2 + k3 * r(x, y)^3;
# beta(x, y) = 1 + k4 * r(x, y) + k5 * r(x, y) ^2 + k6 * r(x, y)^3;
# d(x, y) = alpha(x, y) / beta(x, y);
r = x * x + y * y
alpha = 1.0 + r * (k1 + r * (k2 + r * k3))
beta = 1.0 + r * (k4 + r * (k5 + r * k6))
d = alpha / beta

# The perfect projection is:
# xd = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2);
# yd = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2);
#
# Let's define
#
# fx(x, y) = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2) - xd;
# fy(x, y) = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2) - yd;
#
# We are looking for a solution that satisfies
# fx(x, y) = fy(x, y) = 0;
fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd
fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd

# Compute derivative of alpha, beta over r.
alpha_r = k1 + r * (2.0 * k2 + r * (3.0 * k3))
beta_r = k4 + r * (2.0 * k5 + r * (3.0 * k6))

# Compute derivative of d over [x, y]
d_r = (alpha_r * beta - alpha * beta_r) / (beta * beta)
d_x = 2.0 * x * d_r
d_y = 2.0 * y * d_r

# Compute derivative of fx over x and y.
fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x
fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y

# Compute derivative of fy over x and y.
fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x
fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y

return fx, fy, fx_x, fx_y, fy_x, fy_y


@torch.jit.script
def _opencv_lens_undistortion(
uv: Tensor, params: Tensor, eps: float = 1e-6, iters: int = 10
) -> Tensor:
"""Same as opencv_lens_undistortion(), but native PyTorch.
Took from with bug fix and modification.
https://github.com/nerfstudio-project/nerfstudio/blob/ec603634edbd61b13bdf2c598fda8c993370b8f7/nerfstudio/cameras/camera_utils.py
"""
assert uv.shape[-1] == 2
assert params.shape[-1] in [0, 1, 2, 4, 8]

if params.shape[-1] == 0:
return uv
elif params.shape[-1] < 8:
params = F.pad(params, (0, 8 - params.shape[-1]), "constant", 0.0)
assert params.shape[-1] == 8

# Initialize from the distorted point.
x, y = x0, y0 = torch.unbind(uv, dim=-1)

zeros = torch.zeros_like(x)
for _ in range(iters):
fx, fy, fx_x, fx_y, fy_x, fy_y = _compute_residual_and_jacobian(
x=x, y=y, xd=x0, yd=y0, params=params
)
denominator = fy_x * fx_y - fx_x * fy_y
mask = torch.abs(denominator) > eps

x_numerator = fx * fy_y - fy * fx_y
y_numerator = fy * fx_x - fx * fy_x
step_x = torch.where(mask, x_numerator / denominator, zeros)
step_y = torch.where(mask, y_numerator / denominator, zeros)

x = x + step_x
y = y + step_y

return torch.stack([x, y], dim=-1)
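
For orientation, here is a minimal usage sketch of the two public entry points defined above. It assumes a CUDA-enabled build of the nerfacc extension; the distortion coefficients are made up purely for illustration.

import torch
from nerfacc.cameras import (
    opencv_lens_undistortion,
    opencv_lens_undistortion_fisheye,
)

# Normalized (intrinsics-removed) image coordinates.
uv = torch.rand((1000, 2), device="cuda") - 0.5
# Illustrative coefficients: {k1, k2, p1, p2} for the standard model,
# {k1, k2, k3, k4} for the fisheye model.
params_standard = torch.tensor([1e-2, -2e-3, 1e-3, 5e-4], device="cuda")
params_fisheye = torch.tensor([1e-2, -2e-3, 1e-3, 5e-4], device="cuda")

uv_standard = opencv_lens_undistortion(uv, params_standard)          # (1000, 2)
uv_fisheye = opencv_lens_undistortion_fisheye(uv, params_fisheye)    # (1000, 2)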
160 changes: 160 additions & 0 deletions nerfacc/cameras2.py
@@ -0,0 +1,160 @@
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
Both colmap and nerfstudio appear to be based on OpenCV's camera model.
References:
- nerfstudio: https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/cameras/cameras.py
- opencv:
- https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga69f2545a8b62a6b0fc2ee060dc30559d
- https://docs.opencv.org/3.4/d9/d0c/group__calib3d.html
- https://docs.opencv.org/4.x/db/d58/group__calib3d__fisheye.html
- https://github.com/opencv/opencv/blob/master/modules/calib3d/src/fisheye.cpp#L321
- https://github.com/opencv/opencv/blob/17234f82d025e3bbfbf611089637e5aa2038e7b8/modules/calib3d/src/distortion_model.hpp
- https://github.com/opencv/opencv/blob/8d0fbc6a1e9f20c822921e8076551a01e58cd632/modules/calib3d/src/undistort.dispatch.cpp#L578
- colmap: https://github.com/colmap/colmap/blob/dev/src/base/camera_models.h
- calcam: https://euratom-software.github.io/calcam/html/intro_theory.html
- blender:
- https://docs.blender.org/manual/en/latest/render/cycles/object_settings/cameras.html#fisheye-lens-polynomial
- https://github.com/blender/blender/blob/03cc3b94c94c38767802bccac4e9384ab704065a/intern/cycles/kernel/kernel_projection.h
- lensfun: https://lensfun.github.io/manual/v0.3.2/annotated.html
- OpenCV and Blender have different fisheye camera models
- https://stackoverflow.com/questions/73270140/pipeline-for-fisheye-distortion-and-undistortion-with-blender-and-opencv
"""
from typing import Literal, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor

from . import cuda as _C


def ray_directions_from_uvs(
uvs: Tensor, # [..., 2]
Ks: Tensor, # [..., 3, 3]
params: Optional[Tensor] = None, # [..., M]
) -> Tensor:
"""Create ray directions from uvs and camera parameters in OpenCV format.
Args:
uvs: UV coordinates on the image plane (in pixel units).
Ks: Camera intrinsics.
params: Camera distortion parameters. See `opencv.undistortPoints` for details.
Returns:
Normalized ray directions in camera space.
"""
u, v = torch.unbind(uvs + 0.5, dim=-1)
fx, fy = Ks[..., 0, 0], Ks[..., 1, 1]
cx, cy = Ks[..., 0, 2], Ks[..., 1, 2]

# undo intrinsics
xys = torch.stack([(u - cx) / fx, (v - cy) / fy], dim=-1) # [..., 2]

# undo lens distortion
if params is not None:
M = params.shape[-1]

if M == 14: # undo tilt projection
R, R_inv = opencv_tilt_projection_matrix(params[..., -2:])
xys_homo = F.pad(xys, (0, 1), value=1.0) # [..., 3]
xys_homo = torch.einsum(
"...ij,...j->...i", R_inv, xys_homo
) # [..., 3]
xys = xys_homo[..., :2]
homo = xys_homo[..., 2:]
xys /= torch.where(homo != 0.0, homo, torch.ones_like(homo))

xys = opencv_lens_undistortion(xys, params) # [..., 2]

# normalized homogeneous coordinates
dirs = F.pad(xys, (0, 1), value=1.0) # [..., 3]
dirs = F.normalize(dirs, dim=-1) # [..., 3]
return dirs


def opencv_lens_undistortion(
uv: Tensor, params: Tensor, eps: float = 1e-6, iters: int = 10
) -> Tensor:
"""Undistort the opencv distortion of {k1, k2, k3, k4, p1, p2}.
Note:
This function is not differentiable to any inputs.
Args:
uv: (..., 2) UV coordinates.
params: (..., 6) or (6) OpenCV distortion parameters.
Returns:
(..., 2) undistorted UV coordinates.
"""
assert uv.shape[-1] == 2
assert params.shape[-1] == 6
batch_shape = uv.shape[:-1]
params = torch.broadcast_to(params, batch_shape + (6,))

return _C.opencv_lens_undistortion(
uv.contiguous(), params.contiguous(), eps, iters
)


def opencv_tilt_projection_matrix(tau: Tensor) -> Tuple[Tensor, Tensor]:
"""Create a tilt projection matrix and its inverse.
Reference:
https://docs.opencv.org/3.4/d9/d0c/group__calib3d.html
Args:
tau: (..., 2) tilt angles.
Returns:
A tuple (R, R_inv) of (..., 3, 3) tilt projection matrices.
"""

cosx, cosy = torch.unbind(torch.cos(tau), -1)
sinx, siny = torch.unbind(torch.sin(tau), -1)
one = torch.ones_like(cosx)
zero = torch.zeros_like(cosx)

Rx = torch.stack(
[one, zero, zero, zero, cosx, sinx, zero, -sinx, cosx], -1
).reshape(*tau.shape[:-1], 3, 3)
Ry = torch.stack(
[cosy, zero, -siny, zero, one, zero, siny, zero, cosy], -1
).reshape(*tau.shape[:-1], 3, 3)
Rxy = torch.matmul(Ry, Rx)
Rz = torch.stack(
[
Rxy[..., 2, 2],
zero,
-Rxy[..., 0, 2],
zero,
Rxy[..., 2, 2],
-Rxy[..., 1, 2],
zero,
zero,
one,
],
-1,
).reshape(*tau.shape[:-1], 3, 3)
R = torch.matmul(Rz, Rxy)

inv = 1.0 / Rxy[..., 2, 2]
Rz_inv = torch.stack(
[
inv,
zero,
inv * Rxy[..., 0, 2],
zero,
inv,
inv * Rxy[..., 1, 2],
zero,
zero,
one,
],
-1,
).reshape(*tau.shape[:-1], 3, 3)
R_inv = torch.matmul(Rxy.transpose(-1, -2), Rz_inv)
return R, R_inv
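
To show how `ray_directions_from_uvs` above is typically driven, here is a small sketch that builds per-pixel ray directions for a full image. The intrinsics are made up, and `params=None` keeps the example on the distortion-free path.

import torch
from nerfacc.cameras2 import ray_directions_from_uvs

H, W = 480, 640
K = torch.tensor(
    [[500.0, 0.0, W / 2.0], [0.0, 500.0, H / 2.0], [0.0, 0.0, 1.0]],
    device="cuda",
)
ys, xs = torch.meshgrid(
    torch.arange(H, device="cuda"),
    torch.arange(W, device="cuda"),
    indexing="ij",
)
uvs = torch.stack([xs, ys], dim=-1).float()            # [H, W, 2], pixel units
dirs = ray_directions_from_uvs(uvs, K, params=None)    # [H, W, 3], unit length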
6 changes: 6 additions & 0 deletions nerfacc/cuda/__init__.py
@@ -38,3 +38,9 @@ def call_cuda(*args, **kwargs):
# pdf
importance_sampling = _make_lazy_cuda_func("importance_sampling")
searchsorted = _make_lazy_cuda_func("searchsorted")

# camera
opencv_lens_undistortion = _make_lazy_cuda_func("opencv_lens_undistortion")
opencv_lens_undistortion_fisheye = _make_lazy_cuda_func(
"opencv_lens_undistortion_fisheye"
)
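
For reference, the bindings registered above are thin wrappers that expect contiguous float32 tensors plus explicit `eps`/`iters` arguments, exactly as the Python wrappers in nerfacc/cameras.py call them. A minimal direct call might look like the sketch below (shapes and coefficients are illustrative).

import torch
from nerfacc import cuda as _C

uv = torch.rand((1000, 2), device="cuda").contiguous()
# {k1, k2, p1, p2, k3, k4, k5, k6}, one set per point.
params = (torch.rand((1000, 8), device="cuda") * 0.01).contiguous()
uv_undistorted = _C.opencv_lens_undistortion(uv, params, 1e-6, 10)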
183 changes: 183 additions & 0 deletions nerfacc/cuda/csrc/camera.cu
@@ -0,0 +1,183 @@
#include <torch/extension.h>

#include "include/utils_cuda.cuh"
#include "include/utils_camera.cuh"


namespace {
namespace device {

__global__ void opencv_lens_undistortion_fisheye(
const int64_t N,
const float* uv,
const float* params,
const int criteria_iters,
const float criteria_eps,
float* uv_out,
bool* success)
{
// parallelize over outputs
for (int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < N; tid += blockDim.x * gridDim.x)
{
success[tid] = iterative_opencv_lens_undistortion_fisheye(
uv[tid * 2 + 0],
uv[tid * 2 + 1],
params[tid * 4 + 0], // k1
params[tid * 4 + 1], // k2
params[tid * 4 + 2], // k3
params[tid * 4 + 3], // k4
criteria_iters,
criteria_eps,
uv_out[tid * 2 + 0],
uv_out[tid * 2 + 1]
);
}
}

__global__ void opencv_lens_undistortion(
const int64_t N,
const int64_t n_params,
const float* uv,
const float* params,
const float eps,
const int max_iterations,
float* uv_out)
{
// parallelize over outputs
for (int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < N; tid += blockDim.x * gridDim.x)
{
if (n_params == 5) {
radial_and_tangential_undistort(
uv[tid * 2 + 0],
uv[tid * 2 + 1],
params[tid * n_params + 0], // k1
params[tid * n_params + 1], // k2
params[tid * n_params + 4], // k3
0.f, // k4
0.f, // k5
0.f, // k6
params[tid * n_params + 2], // p1
params[tid * n_params + 3], // p2
eps,
max_iterations,
uv_out[tid * 2 + 0],
uv_out[tid * 2 + 1]);
} else if (n_params == 8) {
radial_and_tangential_undistort(
uv[tid * 2 + 0],
uv[tid * 2 + 1],
params[tid * n_params + 0], // k1
params[tid * n_params + 1], // k2
params[tid * n_params + 4], // k3
params[tid * n_params + 5], // k4
params[tid * n_params + 6], // k5
params[tid * n_params + 7], // k6
params[tid * n_params + 2], // p1
params[tid * n_params + 3], // p2
eps,
max_iterations,
uv_out[tid * 2 + 0],
uv_out[tid * 2 + 1]);
} else if (n_params == 12) {
bool success = iterative_opencv_lens_undistortion(
uv[tid * 2 + 0],
uv[tid * 2 + 1],
params[tid * 12 + 0], // k1
params[tid * 12 + 1], // k2
params[tid * 12 + 2], // k3
params[tid * 12 + 3], // k4
params[tid * 12 + 4], // k5
params[tid * 12 + 5], // k6
params[tid * 12 + 6], // p1
params[tid * 12 + 7], // p2
params[tid * 12 + 8], // s1
params[tid * 12 + 9], // s2
params[tid * 12 + 10], // s3
params[tid * 12 + 11], // s4
max_iterations,
uv_out[tid * 2 + 0],
uv_out[tid * 2 + 1]
);
if (!success) {
uv_out[tid * 2 + 0] = uv[tid * 2 + 0];
uv_out[tid * 2 + 1] = uv[tid * 2 + 1];
}
}
}
}


} // namespace device
} // namespace


torch::Tensor opencv_lens_undistortion(
const torch::Tensor& uv, // [..., 2]
const torch::Tensor& params, // [..., 5], [..., 8], or [..., 12]
const float eps,
const int max_iterations)
{
DEVICE_GUARD(uv);
CHECK_INPUT(uv);
CHECK_INPUT(params);
TORCH_CHECK(uv.ndimension() == params.ndimension());
TORCH_CHECK(uv.size(-1) == 2, "uv must have shape [..., 2]");
TORCH_CHECK(params.size(-1) == 5 || params.size(-1) == 8 || params.size(-1) == 12);

int64_t N = uv.numel() / 2;
int64_t n_params = params.size(-1);

at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int64_t max_threads = 512;
int64_t max_blocks = 65535;
dim3 threads = dim3(min(max_threads, N));
dim3 blocks = dim3(min(max_blocks, ceil_div<int64_t>(N, threads.x)));

auto uv_out = torch::empty_like(uv);
device::opencv_lens_undistortion<<<blocks, threads, 0, stream>>>(
N,
n_params,
uv.data_ptr<float>(),
params.data_ptr<float>(),
eps,
max_iterations,
uv_out.data_ptr<float>());

return uv_out;
}

torch::Tensor opencv_lens_undistortion_fisheye(
const torch::Tensor& uv, // [..., 2]
const torch::Tensor& params, // [..., 4]
const float criteria_eps,
const int criteria_iters)
{
DEVICE_GUARD(uv);
CHECK_INPUT(uv);
CHECK_INPUT(params);
TORCH_CHECK(uv.ndimension() == params.ndimension());
TORCH_CHECK(uv.size(-1) == 2, "uv must have shape [..., 2]");
TORCH_CHECK(params.size(-1) == 4);

int64_t N = uv.numel() / 2;

at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int64_t max_threads = 512;
int64_t max_blocks = 65535;
dim3 threads = dim3(min(max_threads, N));
dim3 blocks = dim3(min(max_blocks, ceil_div<int64_t>(N, threads.x)));

auto uv_out = torch::empty_like(uv);
auto success = torch::empty(
uv.sizes().slice(0, uv.ndimension() - 1), uv.options().dtype(torch::kBool));
device::opencv_lens_undistortion_fisheye<<<blocks, threads, 0, stream>>>(
N,
uv.data_ptr<float>(),
params.data_ptr<float>(),
criteria_iters,
criteria_eps,
uv_out.data_ptr<float>(),
success.data_ptr<bool>());

return uv_out;
}
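
A sanity check one might run against OpenCV itself, assuming opencv-python is installed: with an identity camera matrix, `cv2.undistortPoints` operates directly on normalized coordinates, which matches the convention of the kernels above. The coefficients are made up, and the two solvers use different stopping criteria, so only rough agreement should be expected.

import cv2
import numpy as np
import torch
from nerfacc.cameras import opencv_lens_undistortion

uv = torch.rand((100, 2), device="cuda") - 0.5           # normalized, distorted
params = torch.rand((8,), device="cuda") * 0.01          # {k1, k2, p1, p2, k3, k4, k5, k6}

ours = opencv_lens_undistortion(uv, params, eps=1e-9, iters=20).cpu().numpy()
ref = cv2.undistortPoints(
    uv.cpu().numpy().reshape(-1, 1, 2),   # Nx1x2, as OpenCV expects
    np.eye(3),                            # identity K: inputs are already normalized
    params.cpu().numpy(),
).reshape(-1, 2)
print(np.abs(ours - ref).max())           # expected to be small for mild distortion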
201 changes: 201 additions & 0 deletions nerfacc/cuda/csrc/include/utils_camera.cuh
@@ -0,0 +1,201 @@
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/

#include "utils_cuda.cuh"

#define PI 3.14159265358979323846

namespace {
namespace device {

// https://github.com/JamesPerlman/TurboNeRF/blob/75f1228d41b914b0a768a876d2a851f3b3213a58/src/utils/camera-kernels.cuh
inline __device__ void _compute_residual_and_jacobian(
// inputs
float x, float y,
float xd, float yd,
float k1, float k2, float k3, float k4, float k5, float k6,
float p1, float p2,
// outputs
float& fx, float& fy,
float& fx_x, float& fx_y,
float& fy_x, float& fy_y
) {
// let r(x, y) = x^2 + y^2;
// alpha(x, y) = 1 + k1 * r(x, y) + k2 * r(x, y) ^2 + k3 * r(x, y)^3;
// beta(x, y) = 1 + k4 * r(x, y) + k5 * r(x, y) ^2 + k6 * r(x, y)^3;
// d(x, y) = alpha(x, y) / beta(x, y);
const float r = x * x + y * y;
const float alpha = 1.0f + r * (k1 + r * (k2 + r * k3));
const float beta = 1.0f + r * (k4 + r * (k5 + r * k6));
const float d = alpha / beta;

// The perfect projection is:
// xd = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2);
// yd = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2);

// Let's define
// fx(x, y) = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2) - xd;
// fy(x, y) = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2) - yd;

// We are looking for a solution that satisfies
// fx(x, y) = fy(x, y) = 0;

fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd;
fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd;

// Compute derivative of alpha, beta over r.
const float alpha_r = k1 + r * (2.0 * k2 + r * (3.0 * k3));
const float beta_r = k4 + r * (2.0 * k5 + r * (3.0 * k6));

// Compute derivative of d over [x, y]
const float d_r = (alpha_r * beta - alpha * beta_r) / (beta * beta);
const float d_x = 2.0 * x * d_r;
const float d_y = 2.0 * y * d_r;

// Compute derivative of fx over x and y.
fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x;
fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y;

// Compute derivative of fy over x and y.
fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x;
fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y;
}

// https://github.com/JamesPerlman/TurboNeRF/blob/75f1228d41b914b0a768a876d2a851f3b3213a58/src/utils/camera-kernels.cuh
inline __device__ void radial_and_tangential_undistort(
float xd, float yd,
float k1, float k2, float k3, float k4, float k5, float k6,
float p1, float p2,
const float& eps,
const int& max_iterations,
float& x, float& y
) {
// Initial guess.
x = xd;
y = yd;

// Newton's method.
for (int i = 0; i < max_iterations; ++i) {
float fx, fy, fx_x, fx_y, fy_x, fy_y;

_compute_residual_and_jacobian(
x, y,
xd, yd,
k1, k2, k3, k4, k5, k6,
p1, p2,
fx, fy,
fx_x, fx_y, fy_x, fy_y
);

// Compute the Jacobian.
const float det = fx_y * fy_x - fx_x * fy_y;
if (fabs(det) < eps) {
break;
}

// Compute the update.
const float dx = (fx * fy_y - fy * fx_y) / det;
const float dy = (fy * fx_x - fx * fy_x) / det;

// Update the solution.
x += dx;
y += dy;

// Check for convergence.
if (fabs(dx) < eps && fabs(dy) < eps) {
break;
}
}
}

// not good
// https://github.com/opencv/opencv/blob/8d0fbc6a1e9f20c822921e8076551a01e58cd632/modules/calib3d/src/undistort.dispatch.cpp#L578
inline __device__ bool iterative_opencv_lens_undistortion(
float u, float v,
float k1, float k2, float k3, float k4, float k5, float k6,
float p1, float p2, float s1, float s2, float s3, float s4,
int iters,
// outputs
float& x, float& y)
{
x = u;
y = v;
for(int i = 0; i < iters; i++)
{
float r2 = x*x + y*y;
float icdist = (1 + ((k6*r2 + k5)*r2 + k4)*r2) / (1 + ((k3*r2 + k2)*r2 + k1)*r2);
if (icdist < 0) return false;
float deltaX = 2*p1*x*y + p2*(r2 + 2*x*x) + s1*r2 + s2*r2*r2;
float deltaY = p1*(r2 + 2*y*y) + 2*p2*x*y + s3*r2 + s4*r2*r2;
x = (u - deltaX) * icdist;
y = (v - deltaY) * icdist;
}
return true;
}

// https://github.com/opencv/opencv/blob/master/modules/calib3d/src/fisheye.cpp#L321
inline __device__ bool iterative_opencv_lens_undistortion_fisheye(
float u, float v,
float k1, float k2, float k3, float k4,
int criteria_iters,
float criteria_eps,
// outputs
float& u_out, float& v_out)
{
// image point (u, v) to world point (x, y)
float theta_d = sqrt(u * u + v * v);

// the current camera model is only valid up to 180 FOV
// for larger FOV the loop below does not converge
// clip values so we still get plausible results for super fisheye images > 180 degrees
theta_d = min(max(-PI/2., theta_d), PI/2.);

bool converged = false;
float theta = theta_d;

float scale = 0.0;

if (fabs(theta_d) > criteria_eps)
{
// compensate distortion iteratively using Newton method
for (int j = 0; j < criteria_iters; j++)
{
double theta2 = theta*theta, theta4 = theta2*theta2, theta6 = theta4*theta2, theta8 = theta6*theta2;
double k0_theta2 = k1 * theta2, k1_theta4 = k2 * theta4, k2_theta6 = k3 * theta6, k3_theta8 = k4 * theta8;
/* new_theta = theta - theta_fix, theta_fix = f0(theta) / f0'(theta) */
double theta_fix = (theta * (1 + k0_theta2 + k1_theta4 + k2_theta6 + k3_theta8) - theta_d) /
(1 + 3*k0_theta2 + 5*k1_theta4 + 7*k2_theta6 + 9*k3_theta8);
theta = theta - theta_fix;

if (fabs(theta_fix) < criteria_eps)
{
converged = true;
break;
}
}

scale = std::tan(theta) / theta_d;
}
else
{
converged = true;
}

// theta is monotonously increasing or decreasing depending on the sign of theta
// if theta has flipped, it might converge due to symmetry but on the opposite of the camera center
// so we can check whether theta has changed the sign during the optimization
bool theta_flipped = ((theta_d < 0 && theta > 0) || (theta_d > 0 && theta < 0));

if (converged && !theta_flipped)
{
u_out = u * scale;
v_out = v * scale;
}

return converged;
}


} // namespace device
} // namespace
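
For readers who prefer to study the fisheye solve in Python, here is a native-PyTorch sketch that mirrors `iterative_opencv_lens_undistortion_fisheye` above. It is a reference only: it omits the convergence flag and the theta sign-flip check that the CUDA routine tracks.

import math

import torch
from torch import Tensor


def undistort_fisheye_torch(
    uv: Tensor, params: Tensor, iters: int = 10, eps: float = 1e-6
) -> Tensor:
    # params: (..., 4) or (4,) fisheye coefficients {k1, k2, k3, k4}.
    k1, k2, k3, k4 = torch.unbind(params, dim=-1)
    # The model is only valid up to a 180-degree FOV, so clamp as the kernel does.
    theta_d = torch.sqrt((uv * uv).sum(dim=-1)).clamp(max=math.pi / 2)
    theta = theta_d.clone()
    for _ in range(iters):
        theta2 = theta * theta
        theta4 = theta2 * theta2
        theta6 = theta4 * theta2
        theta8 = theta6 * theta2
        # Newton step on f(theta) = theta * (1 + k1 t^2 + ... + k4 t^8) - theta_d.
        f = theta * (1 + k1 * theta2 + k2 * theta4 + k3 * theta6 + k4 * theta8) - theta_d
        fp = 1 + 3 * k1 * theta2 + 5 * k2 * theta4 + 7 * k3 * theta6 + 9 * k4 * theta8
        theta = theta - f / fp
    scale = torch.tan(theta) / theta_d.clamp(min=eps)
    return uv * scale[..., None]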
78 changes: 47 additions & 31 deletions nerfacc/cuda/csrc/nerfacc.cpp
@@ -88,43 +88,59 @@ std::vector<torch::Tensor> searchsorted(
RaySegmentsSpec query,
RaySegmentsSpec key);

// cameras
torch::Tensor opencv_lens_undistortion(
const torch::Tensor& uv, // [..., 2]
const torch::Tensor& params, // [..., 5], [..., 8], or [..., 12]
const float eps,
const int max_iterations);
torch::Tensor opencv_lens_undistortion_fisheye(
const torch::Tensor& uv, // [..., 2]
const torch::Tensor& params, // [..., 4]
const float criteria_eps,
const int criteria_iters);


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#define _REG_FUNC(funname) m.def(#funname, &funname)
_REG_FUNC(is_cub_available); // TODO: check this function

_REG_FUNC(exclusive_sum_by_key);
_REG_FUNC(inclusive_sum);
_REG_FUNC(exclusive_sum);
_REG_FUNC(inclusive_prod_forward);
_REG_FUNC(inclusive_prod_backward);
_REG_FUNC(exclusive_prod_forward);
_REG_FUNC(exclusive_prod_backward);

_REG_FUNC(ray_aabb_intersect);
_REG_FUNC(traverse_grids);
_REG_FUNC(searchsorted);
_REG_FUNC(opencv_lens_undistortion);
_REG_FUNC(opencv_lens_undistortion_fisheye);
#undef _REG_FUNC

m.def("importance_sampling", py::overload_cast<RaySegmentsSpec, torch::Tensor, torch::Tensor, bool>(&importance_sampling));
m.def("importance_sampling", py::overload_cast<RaySegmentsSpec, torch::Tensor, int64_t, bool>(&importance_sampling));
m.def("importance_sampling", py::overload_cast<RaySegmentsSpec, torch::Tensor, torch::Tensor, bool>(&importance_sampling));
m.def("importance_sampling", py::overload_cast<RaySegmentsSpec, torch::Tensor, int64_t, bool>(&importance_sampling));

py::class_<MultiScaleGridSpec>(m, "MultiScaleGridSpec")
.def(py::init<>())
.def_readwrite("data", &MultiScaleGridSpec::data)
.def_readwrite("occupied", &MultiScaleGridSpec::occupied)
.def_readwrite("base_aabb", &MultiScaleGridSpec::base_aabb);

py::class_<RaysSpec>(m, "RaysSpec")
.def(py::init<>())
.def_readwrite("origins", &RaysSpec::origins)
.def_readwrite("dirs", &RaysSpec::dirs);

py::class_<RaySegmentsSpec>(m, "RaySegmentsSpec")
.def(py::init<>())
.def_readwrite("vals", &RaySegmentsSpec::vals)
.def_readwrite("is_left", &RaySegmentsSpec::is_left)
.def_readwrite("is_right", &RaySegmentsSpec::is_right)
.def_readwrite("chunk_starts", &RaySegmentsSpec::chunk_starts)
.def_readwrite("chunk_cnts", &RaySegmentsSpec::chunk_cnts)
.def_readwrite("ray_indices", &RaySegmentsSpec::ray_indices);
}
42 changes: 42 additions & 0 deletions tests/test_camera.py
@@ -0,0 +1,42 @@
from typing import Tuple

import pytest
import torch
import tqdm
from torch import Tensor

device = "cuda:0"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
@torch.no_grad()
def test_opencv_lens_undistortion():
from nerfacc.cameras import (
_opencv_lens_distortion,
_opencv_lens_distortion_fisheye,
_opencv_lens_undistortion,
opencv_lens_undistortion,
opencv_lens_undistortion_fisheye,
)

torch.manual_seed(42)

x = torch.rand((3, 1000, 2), device=device)

params = torch.rand((8), device=device) * 0.01
x_undistort = opencv_lens_undistortion(x, params, 1e-5, 10)
_x_undistort = _opencv_lens_undistortion(x, params, 1e-5, 10)
assert torch.allclose(x_undistort, _x_undistort, atol=1e-5)
x_distort = _opencv_lens_distortion(x_undistort, params)
assert torch.allclose(x, x_distort, atol=1e-5), (x - x_distort).abs().max()
# print(x[0, 0], x_distort[0, 0], x_undistort[0, 0])

params = torch.rand((4), device=device) * 0.01
x_undistort = opencv_lens_undistortion_fisheye(x, params, 1e-5, 10)
x_distort = _opencv_lens_distortion_fisheye(x_undistort, params)
assert torch.allclose(x, x_distort, atol=1e-5), (x - x_distort).abs().max()
# print(x[0, 0], x_distort[0, 0], x_undistort[0, 0])


if __name__ == "__main__":
test_opencv_lens_undistortion()
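
One aspect the test above does not exercise is per-point distortion parameters, which the wrappers accept as (..., N) tensors. A small, self-contained sketch of that variant (coefficients are again made up):

import torch
from nerfacc.cameras import _opencv_lens_distortion, opencv_lens_undistortion

x = torch.rand((3, 1000, 2), device="cuda")
# One {k1, k2, p1, p2, k3, k4, k5, k6} set per point.
params = torch.rand((3, 1000, 8), device="cuda") * 0.01
x_undistort = opencv_lens_undistortion(x, params, 1e-5, 10)
x_distort = _opencv_lens_distortion(x_undistort, params)
assert torch.allclose(x, x_distort, atol=1e-5)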
