Skip to content

Commit

Permalink
Modernized bundle adjustment interface (colmap#2896)
Browse files Browse the repository at this point in the history
* Defines a single consistent abstract interface for bundle adjustment.
* Reuses functionality between different bundle adjusters using
composition rather than inheritance.
* Extracts the pose prior alignment to the alignment module. Ideally
needs tests but this remains a TODO.
* The problem setup is now performed when the bundle adjuster is created
and, as such, disentangled from the solving of the problem. This renders
the current awkward SetUpProblem unnecessary.
* The idea later is then to pass a bundle adjuster object to a
simplified covariance estimation interface.

---------

Co-authored-by: B1ueber2y <[email protected]>
  • Loading branch information
ahojnnes and B1ueber2y authored Nov 17, 2024
1 parent 01ca00a commit d065cea
Show file tree
Hide file tree
Showing 12 changed files with 1,091 additions and 1,237 deletions.
210 changes: 16 additions & 194 deletions pycolmap/examples/custom_bundle_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,204 +7,23 @@

import copy

import pyceres
import pyceres # noqa F401

import pycolmap
from pycolmap import logging


class PyBundleAdjuster:
# Python implementation of COLMAP bundle adjuster with pyceres
def __init__(
self,
options: pycolmap.BundleAdjustmentOptions,
config: pycolmap.BundleAdjustmentConfig,
):
self.options = options
self.config = config
self.problem = pyceres.Problem()
self.summary = pyceres.SolverSummary()
self.camera_ids = set()
self.point3D_num_observations = dict()

def solve(self, reconstruction: pycolmap.Reconstruction):
loss = self.options.create_loss_function()
self.set_up_problem(reconstruction, loss)
if self.problem.num_residuals() == 0:
return False
solver_options = self.set_up_solver_options(
self.problem, self.options.solver_options
)
pyceres.solve(solver_options, self.problem, self.summary)
return True

def set_up_problem(
self,
reconstruction: pycolmap.Reconstruction,
loss: pyceres.LossFunction,
):
assert reconstruction is not None
self.problem = pyceres.Problem()
for image_id in self.config.image_ids:
self.add_image_to_problem(image_id, reconstruction, loss)
for point3D_id in self.config.variable_point3D_ids:
self.add_point_to_problem(point3D_id, reconstruction, loss)
for point3D_id in self.config.constant_point3D_ids:
self.add_point_to_problem(point3D_id, reconstruction, loss)
self.parameterize_cameras(reconstruction)
self.parameterize_points(reconstruction)
return self.problem

def set_up_solver_options(
self, problem: pyceres.Problem, solver_options: pyceres.SolverOptions
):
bundle_adjuster = pycolmap.BundleAdjuster(self.options, self.config)
return bundle_adjuster.set_up_solver_options(problem, solver_options)

def add_image_to_problem(
self,
image_id: int,
reconstruction: pycolmap.Reconstruction,
loss: pyceres.LossFunction,
):
image = reconstruction.images[image_id]
pose = image.cam_from_world
camera = reconstruction.cameras[image.camera_id]
constant_cam_pose = (
not self.options.refine_extrinsics
) or self.config.has_constant_cam_pose(image.image_id)
num_observations = 0
for point2D in image.points2D:
if not point2D.has_point3D():
continue
num_observations += 1
if point2D.point3D_id not in self.point3D_num_observations:
self.point3D_num_observations[point2D.point3D_id] = 0
self.point3D_num_observations[point2D.point3D_id] += 1
point3D = reconstruction.points3D[point2D.point3D_id]
assert point3D.track.length() > 1
if constant_cam_pose:
cost = pycolmap.cost_functions.ReprojErrorCost(
camera.model, pose, point2D.xy
)
self.problem.add_residual_block(
cost, loss, [point3D.xyz, camera.params]
)
else:
cost = pycolmap.cost_functions.ReprojErrorCost(
camera.model, point2D.xy
)
self.problem.add_residual_block(
cost,
loss,
[
pose.rotation.quat,
pose.translation,
point3D.xyz,
camera.params,
],
)
if num_observations > 0:
self.camera_ids.add(image.camera_id)
# Set pose parameterization
if not constant_cam_pose:
self.problem.set_manifold(
pose.rotation.quat, pyceres.QuaternionManifold()
)
if self.config.has_constant_cam_positions(image_id):
constant_position_idxs = self.config.constant_cam_positions(
image_id
)
self.problem.set_manifold(
pose.translation,
pyceres.SubsetManifold(3, constant_position_idxs),
)

def add_point_to_problem(
self,
point3D_id: int,
reconstruction: pycolmap.Reconstruction,
loss: pyceres.LossFunction,
):
point3D = reconstruction.points3D[point3D_id]
if point3D_id in self.point3D_num_observations:
if (
self.point3D_num_observations[point3D_id]
== point3D.track.length()
):
return
else:
self.point3D_num_observations[point3D_id] = 0
for track_el in point3D.track.elements:
if self.config.has_image(track_el.image_id):
continue
self.point3D_num_observations[point3D_id] += 1
image = reconstruction.images[track_el.image_id]
camera = reconstruction.cameras[image.camera_id]
point2D = image.point2D(track_el.point2D_idx)
if image.camera_id not in self.camera_ids:
self.camera_ids.add(image.camera_id)
self.config.set_constant_cam_intrinsics(image.camera_id)
cost = pycolmap.cost_functions.ReprojErrorCost(
camera.model, image.cam_from_world, point2D.xy
)
self.problem.add_residual_block(
cost, loss, [point3D.xyz, camera.params]
)

def parameterize_cameras(self, reconstruction: pycolmap.Reconstruction):
constant_camera = (
(not self.options.refine_focal_length)
and (not self.options.refine_principal_point)
and (not self.options.refine_extra_params)
)
for camera_id in self.camera_ids:
camera = reconstruction.cameras[camera_id]
if constant_camera or self.config.has_constant_cam_intrinsics(
camera_id
):
self.problem.set_parameter_block_constant(camera.params)
continue
const_camera_params = []
if not self.options.refine_focal_length:
const_camera_params.extend(camera.focal_length_idxs())
if not self.options.refine_principal_point:
const_camera_params.extend(camera.principal_point_idxs())
if not self.options.refine_extra_params:
const_camera_params.extend(camera.extra_point_idxs())
if len(const_camera_params) > 0:
self.problem.set_manifold(
camera.params,
pyceres.SubsetManifold(
len(camera.params), const_camera_params
),
)

def parameterize_points(self, reconstruction: pycolmap.Reconstruction):
for (
point3D_id,
num_observations,
) in self.point3D_num_observations.items():
point3D = reconstruction.points3D[point3D_id]
if point3D.track.length() > num_observations:
self.problem.set_parameter_block_constant(point3D.xyz)
for point3D_id in self.config.constant_point3D_ids:
point3D = reconstruction.points3D[point3D_id]
self.problem.set_parameter_block_constant(point3D.xyz)


def solve_bundle_adjustment(reconstruction, ba_options, ba_config):
bundle_adjuster = pycolmap.BundleAdjuster(ba_options, ba_config)
# alternative equivalent python-based bundle adjustment (slower):
# bundle_adjuster = PyBundleAdjuster(ba_options, ba_config)
bundle_adjuster.set_up_problem(
reconstruction, ba_options.create_loss_function()
)
solver_options = bundle_adjuster.set_up_solver_options(
bundle_adjuster.problem, ba_options.solver_options
bundle_adjuster = pycolmap.create_default_bundle_adjuster(
ba_options, ba_config, reconstruction
)
summary = pyceres.SolverSummary()
pyceres.solve(solver_options, bundle_adjuster.problem, summary)
summary = bundle_adjuster.solve()
# Alternatively, you can customize the existing problem or options as:
# solver_options = ba_options.create_solver_options(
# ba_config, bundle_adjuster.problem
# )
# summary = pyceres.SolverSummary()
# pyceres.solve(solver_options, bundle_adjuster.problem, summary)
return summary


Expand Down Expand Up @@ -240,11 +59,14 @@ def adjust_global_bundle(mapper, mapper_options, ba_options):
ba_config.set_constant_cam_pose(image_id)

# Fix 7-DOFs of the bundle adjustment problem
ba_config.set_constant_cam_pose(reg_image_ids[0])
reg_image_ids_it = iter(reg_image_ids)
first_reg_image_id = next(reg_image_ids_it)
second_reg_image_id = next(reg_image_ids_it)
ba_config.set_constant_cam_pose(first_reg_image_id)
if (not mapper_options.fix_existing_images) or (
reg_image_ids[1] not in mapper.existing_image_ids
second_reg_image_id not in mapper.existing_image_ids
):
ba_config.set_constant_cam_positions(reg_image_ids[1], [0])
ba_config.set_constant_cam_positions(second_reg_image_id, [0])

# Run bundle adjustment
summary = solve_bundle_adjustment(reconstruction, ba_options_tmp, ba_config)
Expand Down
20 changes: 20 additions & 0 deletions pycolmap/examples/custom_incremental_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Python reimplementation of the C++ incremental mapper with equivalent logic.
"""

import argparse
import time
from pathlib import Path

Expand Down Expand Up @@ -319,3 +320,22 @@ def main(
for i in range(reconstruction_manager.size()):
reconstructions[i] = reconstruction_manager.get(i)
return reconstructions


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--database_path", required=True)
parser.add_argument("--image_path", required=True)
parser.add_argument("--input_path", default=None)
parser.add_argument("--output_path", required=True)
return parser.parse_args()


if __name__ == "__main__":
args = parse_args()
main(
database_path=Path(args.database_path),
image_path=Path(args.image_path),
input_path=Path(args.input_path) if args.input_path else None,
output_path=Path(args.output_path),
)
5 changes: 3 additions & 2 deletions src/colmap/controllers/bundle_adjustment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ void BundleAdjustmentController::Run() {
ba_config.SetConstantCamPositions(*(++reg_image_ids_it), {0}); // 2nd image

// Run bundle adjustment.
BundleAdjuster bundle_adjuster(ba_options, ba_config);
bundle_adjuster.Solve(reconstruction_.get());
std::unique_ptr<BundleAdjuster> bundle_adjuster = CreateDefaultBundleAdjuster(
std::move(ba_options), std::move(ba_config), *reconstruction_);
bundle_adjuster->Solve();

run_timer.PrintMinutes();
}
Expand Down
44 changes: 44 additions & 0 deletions src/colmap/estimators/alignment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,50 @@ bool AlignReconstructionToLocations(
return true;
}

bool AlignReconstructionToPosePriors(
const Reconstruction& src_reconstruction,
const std::unordered_map<image_t, PosePrior>& tgt_pose_priors,
const RANSACOptions& ransac_options,
Sim3d* tgt_from_src) {
std::vector<Eigen::Vector3d> src;
std::vector<Eigen::Vector3d> tgt;
src.reserve(tgt_pose_priors.size());
tgt.reserve(tgt_pose_priors.size());

for (const image_t image_id : src_reconstruction.RegImageIds()) {
const auto pose_prior_it = tgt_pose_priors.find(image_id);
if (pose_prior_it != tgt_pose_priors.end() &&
pose_prior_it->second.IsValid()) {
const auto& image = src_reconstruction.Image(image_id);
src.push_back(image.ProjectionCenter());
tgt.push_back(pose_prior_it->second.position);
}
}

if (src.size() < 3) {
LOG(WARNING)
<< "Not enough valid pose priors for PosePrior based alignment!";
return false;
}

if (ransac_options.max_error > 0) {
LORANSAC<SimilarityTransformEstimator<3, true>,
SimilarityTransformEstimator<3, true>>
ransac(ransac_options);

const auto report = ransac.Estimate(src, tgt);

if (report.success) {
*tgt_from_src = Sim3d::FromMatrix(report.model);
return true;
}
} else {
return EstimateSim3d(src, tgt, *tgt_from_src);
}

return false;
}

bool AlignReconstructionsViaReprojections(
const Reconstruction& src_reconstruction,
const Reconstruction& tgt_reconstruction,
Expand Down
15 changes: 11 additions & 4 deletions src/colmap/estimators/alignment.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,19 @@
namespace colmap {

bool AlignReconstructionToLocations(
const Reconstruction& reconstruction,
const std::vector<std::string>& image_names,
const std::vector<Eigen::Vector3d>& locations,
const Reconstruction& src_reconstruction,
const std::vector<std::string>& tgt_image_names,
const std::vector<Eigen::Vector3d>& tgt_image_locations,
int min_common_images,
const RANSACOptions& ransac_options,
Sim3d* tform);
Sim3d* tgt_from_src);

// TODO: Needs a unit test.
bool AlignReconstructionToPosePriors(
const Reconstruction& src_reconstruction,
const std::unordered_map<image_t, PosePrior>& tgt_pose_priors,
const RANSACOptions& ransac_options,
Sim3d* tgt_from_src);

// Robustly compute alignment between reconstructions by finding images that
// are registered in both reconstructions. The alignment is then estimated
Expand Down
Loading

0 comments on commit d065cea

Please sign in to comment.