Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
duembgen committed Jan 11, 2024
1 parent ceb0d3c commit 7ce28d3
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 24 deletions.
3 changes: 2 additions & 1 deletion _scripts/run_other_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
DEBUG = False

RESULTS_DIR = "_results"
#RESULTS_DIR = "_results_server"
# RESULTS_DIR = "_results_server"


def lifter_tightness(
Lifter=MonoLifter, robust: bool = False, d: int = 2, n_landmarks=4, n_outliers=0
Expand Down
13 changes: 10 additions & 3 deletions examples/robust_lifter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,20 @@ def get_problem(robust=True):
# Create lifter
np.random.seed(0)
n_landmarks = 4
d= 3
d = 3
if robust:
n_outliers = 1
else:
n_outliers = 0

lifter = Lifter(d=d, n_landmarks=n_landmarks+n_outliers, robust=robust, n_outliers=n_outliers, level=level, variable_list=None)
lifter = Lifter(
d=d,
n_landmarks=n_landmarks + n_outliers,
robust=robust,
n_outliers=n_outliers,
level=level,
variable_list=None,
)
Q, y = lifter.get_Q()

from auto_template.learner import Learner
Expand Down Expand Up @@ -71,4 +78,4 @@ def plot_problem(prob, lifter, fname=""):
prob, lifter = get_problem(robust=robust)
fname = f"certifiable-tools/_examples/test_prob_{number}G.pkl"
save_test_problem(**prob, fname=fname)
#plot_problem(prob, lifter, fname=fname.replace(".pkl", ".png"))
# plot_problem(prob, lifter, fname=fname.replace(".pkl", ".png"))
4 changes: 2 additions & 2 deletions lifters/state_lifter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import numpy as np
import scipy.sparse as sp
from cert_tools.linalg_tools import get_nullspace

from lifters.base_class import BaseClass
from poly_matrix import PolyMatrix, unroll
from utils.common import upper_triangular

from poly_matrix import PolyMatrix, unroll


def ravel_multi_index_triu(index_tuple, shape):
"""Equivalent of np.multi_index_triu, but using only the upper-triangular part of matrix."""
Expand Down
4 changes: 2 additions & 2 deletions lifters/stereo_lifter.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ class StereoLifter(StateLifter, ABC):
["h", "x"],
["h", "z_0"],
["h", "x", "z_0"],
["h", "z_0", "z_1"], # should achieve tightness here
# ["h", "x", "z_0", "z_1"],
["h", "z_0", "z_1"], # should achieve tightness here
# ["h", "x", "z_0", "z_1"],
# ["h", "z_0", "z_1", "z_2"],
]

Expand Down
28 changes: 14 additions & 14 deletions lifters/wahba_lifter.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,17 @@ def get_Q(self, noise: float = None):

def get_Q_from_y(self, y):
"""
every cost term can be written as
(1 + wi)/b^2 r^2(x, zi) + (1 - wi)
residual term:
(Rpi + t - ui).T Wi (Rpi + t - ui) =
[t', vec(R)'] @ [I (pi x I)]' @ Wi @ [I (pi x I)] @ [t ; vec(R)]
------x'----- -----Pi'-----
- 2 [t', vec(R)'] @ [I (pi x I)]' Wi @ ui
-----x'------ ---------Pi_xl--------
+ ui.T @ Wi @ ui
-----Pi_ll------
every cost term can be written as
(1 + wi)/b^2 r^2(x, zi) + (1 - wi)
residual term:
(Rpi + t - ui).T Wi (Rpi + t - ui) =
[t', vec(R)'] @ [I (pi x I)]' @ Wi @ [I (pi x I)] @ [t ; vec(R)]
------x'----- -----Pi'-----
- 2 [t', vec(R)'] @ [I (pi x I)]' Wi @ ui
-----x'------ ---------Pi_xl--------
+ ui.T @ Wi @ ui
-----Pi_ll------
"""
from poly_matrix.poly_matrix import PolyMatrix

Expand All @@ -107,9 +107,9 @@ def get_Q_from_y(self, y):
ui = y[i]
Pi = np.c_[np.eye(self.d), np.kron(pi, np.eye(self.d))]

Pi_ll = ui.T @ Wi @ ui
Pi_xl = -(Pi.T @ Wi @ ui)[:, None]
Qi = Pi.T @ Wi @ Pi
Pi_ll = ui.T @ Wi @ ui
Pi_xl = -(Pi.T @ Wi @ ui)[:, None]
Qi = Pi.T @ Wi @ Pi
if NORMALIZE:
Pi_ll /= norm
Pi_xl /= norm
Expand Down
2 changes: 1 addition & 1 deletion poly_matrix
1 change: 0 additions & 1 deletion solvers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ def solve_sdp_cvxpy(
return X, info



def solve_sdp_cvxpy_new(
Q,
A_b_list,
Expand Down

0 comments on commit 7ce28d3

Please sign in to comment.