GBP optimizer #427 (Draft)

joeaortiz wants to merge 74 commits into main from joe.gbp_optimizer

Commits (74)
1918b45
gbp implementation for pose graph problem, euclidean and lie algebra
joeaortiz Mar 16, 2022
6edf81b
Merge branch 'main' into joe.gbp_optimizer
joeaortiz Mar 16, 2022
31a8874
gbp uses exp_map jacobians
joeaortiz Mar 16, 2022
1a6eae4
gaussian for Manifold rather than Variable class
joeaortiz Mar 16, 2022
a7a6e4d
Merge branch 'main' into joe.gbp_optimizer
joeaortiz Mar 17, 2022
e93aa49
uses proper theseus exp_map jacobian
joeaortiz Apr 4, 2022
aa7062e
gaussian class plus marginals and message class
joeaortiz Apr 5, 2022
4479f3b
updated gaussian class
joeaortiz Apr 7, 2022
fc12a98
message scheduler
joeaortiz Apr 8, 2022
264ed22
added mean damping in lin point space
joeaortiz Apr 13, 2022
d5937b4
Merge branch 'main' into joe.gbp_optimizer
joeaortiz Apr 20, 2022
8092a03
fix in linearise and add ba example
joeaortiz Apr 20, 2022
106b57d
Merge branch 'main' into joe.gbp_optimizer
joeaortiz Apr 20, 2022
4100342
use th.Manifold gaussian
joeaortiz Apr 22, 2022
b26e1a3
ba tests, fixing numerical issues
joeaortiz Apr 25, 2022
b4c387a
ba viewer
joeaortiz May 5, 2022
b75cfca
bundle adjustment trimesh vis
joeaortiz Jun 1, 2022
40575f4
soft huber-like loss on norm of x,y error
joeaortiz Jun 1, 2022
a91754e
remove prints and fix viewer
joeaortiz Jun 6, 2022
32380fb
remove symmetric check
joeaortiz Jun 6, 2022
2aa58a1
ba visualisation and derivatives setup for pgo
joeaortiz Jun 6, 2022
4c91944
lin system damping for ftov msgs
joeaortiz Jun 10, 2022
d1244e8
static dense solver methods
joeaortiz Jun 10, 2022
87cac1e
ba with damping in linear system
joeaortiz Jun 10, 2022
34133be
backward modes
joeaortiz Jun 10, 2022
4263ef5
msgs are class variables to fix implicit backward mode
joeaortiz Jun 10, 2022
7b1bf78
test different backward modes for pgo
joeaortiz Jun 10, 2022
b6f6748
Merge branch 'main' into joe.gbp_optimizer
joeaortiz Jun 10, 2022
6c07c34
fixed copy_impl for reprojection error fn
joeaortiz Jun 11, 2022
d2bd2ae
fix order of args in copy fn
joeaortiz Jun 12, 2022
9f6eb09
used vectorization for part of relin, rename VariableDifference
joeaortiz Jun 12, 2022
78bb120
handles batched problems
joeaortiz Jul 4, 2022
c1b2449
vectorized relinearization and ftov msg passing, schedule class
joeaortiz Jul 5, 2022
c7d9b1c
added missing aux vars to reprojection error cf
joeaortiz Jul 5, 2022
c81210a
Merge branch main into joe.gbp_optimizer
joeaortiz Jul 6, 2022
954a154
vectorized vtof msg passing
joeaortiz Jul 7, 2022
8bcd2cd
handles vectorized inversion with some singular matrices, only comput…
joeaortiz Jul 7, 2022
eaeab00
removed random message schedule
joeaortiz Jul 8, 2022
0700e2e
merge with main
joeaortiz Jul 15, 2022
d8c774a
damping linear system
joeaortiz Jul 18, 2022
ea5173f
local linear damping, fixes for gbp on gpu
joeaortiz Jul 18, 2022
c2507e3
handle loading different format bal file and drop observations
joeaortiz Jul 18, 2022
400091b
gbp check unary factor, fix bug in ba viewer
joeaortiz Jul 20, 2022
9db9612
ba error plot
joeaortiz Jul 20, 2022
12f0b77
fixed are calculation
joeaortiz Jul 22, 2022
255a5b5
tensor for linear system damping, rename message damping
joeaortiz Aug 2, 2022
634ce74
fixes bug where beliefs and factors are created twice
joeaortiz Aug 4, 2022
39be1f1
nesterov, no grad for lm damping, ba batch experiments
joeaortiz Aug 26, 2022
25e87de
nesterov acceleration, two modes
joeaortiz Sep 6, 2022
4e1a4b1
swarm exp
joeaortiz Sep 8, 2022
37f7ed5
learning target for agents
joeaortiz Sep 9, 2022
a6b7274
target character and joint mlp + gbp
joeaortiz Sep 16, 2022
89f0c78
fixed jacobians for gnn factor
joeaortiz Sep 20, 2022
8a75767
implicit backward mode for GBP using GN step
joeaortiz Sep 26, 2022
75c3157
implicit derivatives using gbp and plot backward modes against time
joeaortiz Sep 30, 2022
a7db963
Merge main
joeaortiz Dec 22, 2022
53e4b91
moved into optimizer, removed experiments
joeaortiz Dec 30, 2022
82fc597
Moved import, fixed single wrapper vectorization
joeaortiz Jan 4, 2023
8a01154
update vectorization before truncated steps
joeaortiz Jan 4, 2023
0644e94
Moved bundle adjustment edits to experimental branch
joeaortiz Jan 4, 2023
4a81901
flake8 on github not gitlab
joeaortiz Jan 4, 2023
4489141
Remove nesterov acceleration and timing
joeaortiz Jan 5, 2023
9d5d5f7
End of iter callback, updated mypy version
joeaortiz Jan 6, 2023
748fe38
First attempt at GBP linear solver test
joeaortiz Jan 6, 2023
e597451
Fixed poor conditioning problems with linear test
joeaortiz Jan 6, 2023
c2cebbe
dropout starts later
joeaortiz Jan 6, 2023
d90138f
Merge branch 'main' into joe.gbp_optimizer
joeaortiz Jan 9, 2023
a94225f
Comments and references for understanding GBP code
joeaortiz Jan 13, 2023
b3551fa
Reduced atol threshold for symmetric precision matrix
joeaortiz Jan 13, 2023
aade379
Detach hessian in implicit GBP backward mode
joeaortiz Jan 16, 2023
4c905b9
Fix linearization for truncated, exception for DLM
joeaortiz Jan 16, 2023
4331f71
Fixed bug in linear system damping with vectorization
joeaortiz Jan 17, 2023
633f670
Zero messages correctly when using vectorization
joeaortiz Jan 19, 2023
f9a1c8d
Merged NL optimizer hierarchy refactor
joeaortiz Feb 6, 2023
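
The diffs below add a GaussianBeliefPropagation optimizer and export it as th.GaussianBeliefPropagation. A minimal usage sketch, based only on the API exercised in the new test further down (the tiny two-variable objective here is illustrative, not taken from this PR):

import torch
import theseus as th

# Two 1-D variables: a prior-style factor on x0 and a smoothness factor between
# x0 and x1. Hypothetical toy setup; only the optimizer calls mirror the new test.
x0 = th.Vector(tensor=torch.rand(1, 1), name="x0")
x1 = th.Vector(tensor=torch.rand(1, 1), name="x1")
meas = th.Vector(tensor=torch.ones(1, 1), name="meas")
zero = th.Vector(tensor=torch.zeros(1, 1), name="zero")
w = th.ScaleCostWeight(1.0, name="w")

objective = th.Objective()
objective.add(th.Difference(x0, meas, w))
objective.add(th.Between(x0, x1, zero, w))

optimizer = th.GaussianBeliefPropagation(objective, max_iterations=50, vectorize=True)
objective.update({"x0": x0.tensor, "x1": x1.tensor})
info = optimizer.optimize(
    track_best_solution=True,
    track_err_history=True,
    ftov_msg_damping=0.0,                     # factor-to-variable message damping
    dropout=0.0,                              # random message dropout
    lin_system_damping=torch.tensor([1e-4]),  # damping of the per-factor linear systems
)
print(x0.tensor, x1.tensor, info.best_iter)

The keyword arguments ftov_msg_damping, dropout, and lin_system_damping mirror the options swept in the test; per the test's docstring, none of them should change the converged solution of a linear problem.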
171 changes: 171 additions & 0 deletions tests/optimizer/linear/test_gbp_linear_solver.py
@@ -0,0 +1,171 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

import theseus as th


"""
Build linear 1D surface estimation problem.
Solve using GBP and using matrix inversion and compare answers.
GBP exactly computes the marginal means on convergence.

All the following cases should not affect the converged solution:
- with / without vectorization
- with / without factor to variable message damping
- with / without dropout
- with / without factor linear system damping
"""


def _check_info(info, batch_size, max_iterations, initial_error, objective):
assert info.err_history.shape == (batch_size, max_iterations + 1)
assert info.err_history[:, 0].allclose(initial_error)
assert info.err_history.argmin(dim=1).allclose(info.best_iter + 1)
last_error = objective.error_squared_norm() / 2
last_convergence_idx = info.converged_iter.max().item()
assert info.err_history[:, last_convergence_idx].allclose(last_error)


def run_gbp_linear_solver(
frac_loops,
vectorize=True,
ftov_damping=0.0,
dropout=0.0,
lin_system_damping=torch.tensor([1e-4]),
):
max_iterations = 200

n_variables = 100
batch_size = 1

torch.manual_seed(0)

# initial input tensors
# measurements come from x = sin(t / 10) * t**2 / 250 + 1 with random noise added
ts = torch.arange(n_variables)
true_meas = torch.sin(ts / 10.0) * ts * ts / 250.0 + 1
noisy_meas = true_meas[None, :].repeat(batch_size, 1)
noisy_meas += torch.normal(torch.zeros_like(noisy_meas), 1.0)

variables = []
meas_vars = []
for i in range(n_variables):
variables.append(th.Vector(tensor=torch.rand(batch_size, 1), name=f"x_{i}"))
meas_vars.append(th.Vector(tensor=torch.rand(batch_size, 1), name=f"meas_x{i}"))

objective = th.Objective()

# measurement cost functions
meas_weight = th.ScaleCostWeight(5.0, name="meas_weight")
for var, meas in zip(variables, meas_vars):
objective.add(th.Difference(var, meas, meas_weight))

# smoothness cost functions between adjacent variables
smoothness_weight = th.ScaleCostWeight(2.0, name="smoothness_weight")
zero = th.Vector(tensor=torch.zeros(batch_size, 1), name="zero")
for i in range(n_variables - 1):
objective.add(
th.Between(variables[i], variables[i + 1], zero, smoothness_weight)
)

# difference cost functions between non-adjacent variables to give
# off-diagonal elements in the information matrix
difference_weight = th.ScaleCostWeight(1.0, name="difference_weight")
for i in range(int(n_variables * frac_loops)):
ix1, ix2 = torch.randint(n_variables, (2,))
diff = th.Vector(
tensor=torch.tensor([[true_meas[ix2] - true_meas[ix1]]]), name=f"diff{i}"
)
diff.tensor += torch.normal(torch.zeros(1, 1), 0.2)
objective.add(
th.Between(variables[ix1], variables[ix2], diff, difference_weight)
)

input_tensors = {}
for var in variables:
input_tensors[var.name] = var.tensor
for i in range(len(noisy_meas[0])):
input_tensors[f"meas_x{i}"] = noisy_meas[:, i][:, None]

# Solve with GBP
optimizer = th.GaussianBeliefPropagation(
objective, max_iterations=max_iterations, vectorize=vectorize
)
optimizer.set_params(max_iterations=max_iterations)
objective.update(input_tensors)
initial_error = objective.error_squared_norm() / 2

callback_expected_iter = [0]

def callback(opt_, info_, _, it_):
assert opt_ is optimizer
assert isinstance(info_, th.optimizer.OptimizerInfo)
assert it_ == callback_expected_iter[0]
callback_expected_iter[0] += 1

info = optimizer.optimize(
track_best_solution=True,
track_err_history=True,
end_iter_callback=callback,
ftov_msg_damping=ftov_damping,
dropout=dropout,
lin_system_damping=lin_system_damping,
verbose=True,
)
gbp_solution = [var.tensor.clone() for var in variables]

# Solve with linear solver
objective.update(input_tensors)
linear_optimizer = th.LinearOptimizer(objective, th.CholeskyDenseSolver)
linear_optimizer.optimize(verbose=True)
lin_solution = [var.tensor.clone() for var in variables]

# Solve with Gauss-Newton
# If the problem is poorly conditioned, Gauss-Newton can yield a slightly
# different solution from a single linear solve, so check against both
objective.update(input_tensors)
gn_optimizer = th.GaussNewton(objective, th.CholeskyDenseSolver)
gn_optimizer.optimize(verbose=True)
gn_solution = [var.tensor.clone() for var in variables]

# checks
for x, x_target in zip(gbp_solution, lin_solution):
assert x.allclose(x_target, rtol=1e-3)
for x, x_target in zip(gbp_solution, gn_solution):
assert x.allclose(x_target, rtol=1e-3)
_check_info(info, batch_size, max_iterations, initial_error, objective)

# # Visualise reconstructed surface
# soln_vec = torch.cat(gbp_solution, dim=1)[0]
# import matplotlib.pylab as plt
# plt.scatter(torch.arange(n_variables), soln_vec, label="solution")
# plt.scatter(torch.arange(n_variables), noisy_meas[0], label="meas")
# plt.legend()
# plt.show()


def test_gbp_linear_solver():

# problems with increasing loopiness
# the loopier the problem, the fewer iterations it takes to converge
frac_loops = [0.1, 0.2, 0.5]
for frac in frac_loops:

run_gbp_linear_solver(frac_loops=frac)

# with factor-to-variable message damping; may take too many steps to converge
# run_gbp_linear_solver(frac_loops=frac, ftov_damping=0.1)
# with dropout
run_gbp_linear_solver(frac_loops=frac, dropout=0.1)

# test linear system damping
run_gbp_linear_solver(frac_loops=frac, lin_system_damping=torch.tensor([0.0]))
run_gbp_linear_solver(frac_loops=frac, lin_system_damping=torch.tensor([1e-2]))
run_gbp_linear_solver(frac_loops=frac, lin_system_damping=torch.tensor([1e-6]))

# test without vectorization once
run_gbp_linear_solver(frac_loops=0.5, vectorize=False)
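
The test above can be run on its own with a standard pytest invocation (assuming the repository's usual test dependencies; this command is not part of the diff):

python -m pytest tests/optimizer/linear/test_gbp_linear_solver.py -v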
1 change: 1 addition & 0 deletions theseus/__init__.py
@@ -66,6 +66,7 @@
)
from .optimizer import ( # usort: skip
DenseLinearization,
GaussianBeliefPropagation,
Linearization,
ManifoldGaussian,
OptimizerInfo,
13 changes: 13 additions & 0 deletions theseus/core/objective.py
@@ -111,6 +111,10 @@ def __init__(
# If vectorization is on, this will also handle vectorized containers
self._vectorization_to: Optional[Callable] = None

self.vectorized_cost_fns: Optional[List[CostFunction]] = None
# nested list with the names of the base cost functions wrapped by each vectorized cost function
self.vectorized_cf_names: Optional[List[List[str]]] = None

# If vectorization is on, this gets replaced by a vectorized version
self._retract_method = Objective._retract_base

@@ -682,6 +686,8 @@ def _enable_vectorization(
vectorization_run_fn: Callable,
vectorized_to: Callable,
vectorized_retract_fn: Callable,
vectorized_cost_fns: List[CostFunction],
vectorized_cf_names: List[List[str]],
error_iter_fn: Callable[[], Iterable[CostFunction]],
enabler: Any,
):
@@ -694,6 +700,8 @@
self._vectorization_run = vectorization_run_fn
self._vectorization_to = vectorized_to
self._retract_method = vectorized_retract_fn
self.vectorized_cost_fns = vectorized_cost_fns
self.vectorized_cf_names = vectorized_cf_names
self._get_error_iter = error_iter_fn
self._vectorized = True

@@ -703,6 +711,8 @@ def disable_vectorization(self):
self._vectorization_run = None
self._vectorization_to = None
self._retract_method = Objective._retract_base
self.vectorized_cost_fns = None
self.vectorized_cf_names = None
self._get_error_iter = self._get_error_iter_base
self._vectorized = False

@@ -713,6 +723,9 @@ def vectorized(self):
== (self._vectorized_jacobians_iter is None)
== (self._vectorization_run is None)
== (self._vectorization_to is None)
== (self._retract_method is Objective._retract_base)
== (self.vectorized_cost_fns is None)
== (self.vectorized_cf_names is None)
== (self._get_error_iter == self._get_error_iter_base)
== (self._retract_method == Objective._retract_base)
)
14 changes: 10 additions & 4 deletions theseus/core/vectorizer.py
@@ -116,13 +116,17 @@ def __init__(self, objective: Objective, empty_cuda_cache: bool = False):
_CostFunctionSchema, List[_CostFunctionWrapper]
] = defaultdict(list)

schema_cf_names_dict: Dict[_CostFunctionSchema, List[str]] = defaultdict(list)

# Create wrappers for all cost functions and also get their schemas
for cost_fn in objective.cost_functions.values():
wrapper = _CostFunctionWrapper(cost_fn)
self._cost_fn_wrappers.append(wrapper)
schema = _get_cost_function_schema(cost_fn)
self._schema_dict[schema].append(wrapper)

schema_cf_names_dict[schema].append(cost_fn.name)

# Now create a vectorized cost function for each unique schema
self._vectorized_cost_fns: Dict[_CostFunctionSchema, CostFunction] = {}
for schema in self._schema_dict:
@@ -146,6 +150,8 @@ def __init__(self, objective: Objective, empty_cuda_cache: bool = False):
self._vectorize,
self._to,
self._vectorized_retract_optim_vars,
list(self._vectorized_cost_fns.values()),
list(schema_cf_names_dict.values()),
self._get_vectorized_error_iter,
self,
)
@@ -391,10 +397,10 @@ def _vectorize(
}
ret = [cf for cf_list in schema_dict.values() for cf in cf_list]
for schema, cost_fn_wrappers in schema_dict.items():
if len(cost_fn_wrappers) == 1:
self._handle_singleton_wrapper(schema, cost_fn_wrappers, mode)
else:
self._handle_schema_vectorization(schema, cost_fn_wrappers, mode)
# if len(cost_fn_wrappers) == 1:
# self._handle_singleton_wrapper(schema, cost_fn_wrappers, mode)
# else:
self._handle_schema_vectorization(schema, cost_fn_wrappers, mode)
return ret

def _get_vectorized_error_iter(self) -> Iterable[_CostFunctionWrapper]:
1 change: 1 addition & 0 deletions theseus/optimizer/__init__.py
@@ -9,3 +9,4 @@
from .optimizer import Optimizer, OptimizerInfo
from .sparse_linearization import SparseLinearization
from .variable_ordering import VariableOrdering
from .gbp import GaussianBeliefPropagation, GBPSchedule