Skip to content

Commit

Permalink
Fix ls_problem
Browse files Browse the repository at this point in the history
  • Loading branch information
Frederike Duembgen committed Jul 8, 2023
1 parent 26da057 commit 4854a0a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
27 changes: 24 additions & 3 deletions poly_matrix/least_squares_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@ def __init__(self):

def get_B_matrix(self, variables, output_type="csc"):
return self.B.get_matrix(
variables=({m: 1 for m in range(self.m)}, variables),
variables=([m for m in range(self.m)], variables),
output_type=output_type,
)

def get_Q(self):
if self.Q is None:
self.Q = self.B.transpose().multiply(self.B)
# self.Q.get_matrix(variables=variables, output_type=output_type)
return self.Q

def add_residual(self, res_dict: dict):
Expand All @@ -39,4 +38,26 @@ def add_residual(self, res_dict: dict):
for key, val in res_dict.items():
self.B[self.m, key] += val
self.m += 1
return
return

# old implementation directly constructs Q.
for diag, val in res_dict.items():
# forbid 1-dimensional arrays cause they are ambiguous.
assert np.ndim(val) in [
0,
2,
]
if np.ndim(val) == 0:
self[diag, diag] += val**2
else:
self[diag, diag] += val @ val.T

for off_diag_pair in itertools.combinations(res_dict.items(), 2):
dict0, dict1 = off_diag_pair

if np.ndim(dict1[1]) > 0:
new_val = dict0[1] * dict1[1].T
else:
new_val = dict0[1] * dict1[1]
# new value is an array:
self[dict0[0], dict1[0]] += new_val
14 changes: 7 additions & 7 deletions poly_matrix/poly_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,14 +239,10 @@ def __setitem__(self, key_pair, val, symmetric=None):
# make sure the dimensions of new block are consistent with
# previously inserted blocks.
if key_i in self.adjacency_i.keys():
assert (
val.shape[0] == self.variable_dict_i[key_i]
), f"mismatch in height of filled value for key_i {key_i}: got {val.shape[0]} but expected {self.variable_dict_i[key_i]}"
assert val.shape[0] == self.variable_dict_i[key_i], f"mismatch in height of filled value for key_i {key_i}: got {val.shape[0]} but expected {self.variable_dict_i[key_i]}"

if key_j in self.adjacency_j.keys():
assert (
val.shape[1] == self.variable_dict_j[key_j],
), f"mismatch in width of filled value for key_j {key_j}: {val.shape[1]} but expected {self.variable_dict_j[key_j]}"
assert val.shape[1] == self.variable_dict_j[key_j], f"mismatch in width of filled value for key_j {key_j}: {val.shape[1]} but expected {self.variable_dict_j[key_j]}"
self.add_key_pair(key_i, key_j)

if key_i == key_j:
Expand Down Expand Up @@ -542,7 +538,7 @@ def get_matrix_sparse(self, variables=None, output_type="coo", verbose=False):
assert values.shape == (
variable_dict["i"][key_i],
variable_dict["j"][key_j],
), f"Variable size does not match input matrix size, variables: {(variable_dict_i[key_i],variable_dict_j[key_j])}, matrix: {values.shape}"
), f"Variable size does not match input matrix size, variables: {(variable_dict['i'][key_i],variable_dict['j'][key_j])}, matrix: {values.shape}"
# generate list of indices for sparse mat input
rows, cols = np.nonzero(values)
i_list = np.append(i_list, rows + indices_i[key_i])
Expand Down Expand Up @@ -822,6 +818,10 @@ def __add__(self, other, inplace=False):
def __sub__(self, other):
return self + (other * (-1))

def __div__(self, scalar):
""" overload M / a, for some reason this has no effect"""
return self * (1/scalar)

def __rmul__(self, scalar, inplace=False):
"""Overload a * M"""
return self.__mul__(scalar, inplace)
Expand Down

0 comments on commit 4854a0a

Please sign in to comment.