Skip to content

Commit

Permalink
Formattin
Browse files Browse the repository at this point in the history
  • Loading branch information
duembgen committed Jan 10, 2024
1 parent 14ee250 commit 5fa534a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
2 changes: 1 addition & 1 deletion example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
""" Usage example for PolyMatrix class. """
""" Usage example for PolyMatrix class. """

from poly_matrix import PolyMatrix
import numpy as np
Expand Down
30 changes: 19 additions & 11 deletions poly_matrix/poly_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def unroll(var_dict):
var_dict_unrolled[f"{key}:{j}"] = 1
return var_dict_unrolled


def augment(var_dict):
"""Create new dict to make conversion from sparse (indexed by 0 to N-1)
to polymatrix (indexed by var_dict) easier.
Expand Down Expand Up @@ -239,10 +240,14 @@ def __setitem__(self, key_pair, val, symmetric=None):
# make sure the dimensions of new block are consistent with
# previously inserted blocks.
if key_i in self.adjacency_i.keys():
assert val.shape[0] == self.variable_dict_i[key_i], f"mismatch in height of filled value for key_i {key_i}: got {val.shape[0]} but expected {self.variable_dict_i[key_i]}"
assert (
val.shape[0] == self.variable_dict_i[key_i]
), f"mismatch in height of filled value for key_i {key_i}: got {val.shape[0]} but expected {self.variable_dict_i[key_i]}"

if key_j in self.adjacency_j.keys():
assert val.shape[1] == self.variable_dict_j[key_j], f"mismatch in width of filled value for key_j {key_j}: {val.shape[1]} but expected {self.variable_dict_j[key_j]}"
assert (
val.shape[1] == self.variable_dict_j[key_j]
), f"mismatch in width of filled value for key_j {key_j}: {val.shape[1]} but expected {self.variable_dict_j[key_j]}"
self.add_key_pair(key_i, key_j)

if key_i == key_j:
Expand Down Expand Up @@ -543,8 +548,8 @@ def get_matrix_sparse(self, variables=None, output_type="coo", verbose=False):
rows, cols = np.nonzero(values)
i_list = np.append(i_list, rows + indices_i[key_i])
j_list = np.append(j_list, cols + indices_j[key_j])
data_list = np.append(data_list, values[rows,cols])
data_list = np.append(data_list, values[rows, cols])

if verbose:
print(f"Filling took {time.time() - t1:.2}s.")

Expand Down Expand Up @@ -697,7 +702,7 @@ def _plot_matrix(
tick_locs += [first + i for i in range(sz)]
if sz > 1:
if reduced_ticks:
tick_lbls += [f"{var}"] + ["" for i in range(sz-1)]
tick_lbls += [f"{var}"] + ["" for i in range(sz - 1)]
else:
tick_lbls += [f"{var}:{i}" for i in range(sz)]
else:
Expand All @@ -706,9 +711,7 @@ def _plot_matrix(
tick_fun(ticks=tick_locs, labels=tick_lbls, fontsize=10)
return im

def spy(
self, variables: dict = None, variables_i=None, variables_j=None, **kwargs
):
def spy(self, variables: dict = None, variables_i=None, variables_j=None, **kwargs):
fig, ax = plt.subplots()
im = self._plot_matrix(
plot_type="sparse",
Expand All @@ -721,7 +724,12 @@ def spy(
return fig, ax, im

def matshow(
self, variables: dict = None, variables_i=None, variables_j=None, ax=None, **kwargs
self,
variables: dict = None,
variables_i=None,
variables_j=None,
ax=None,
**kwargs,
):
if ax is None:
fig, ax = plt.subplots()
Expand Down Expand Up @@ -819,8 +827,8 @@ def __sub__(self, other):
return self + (other * (-1))

def __div__(self, scalar):
""" overload M / a, for some reason this has no effect"""
return self * (1/scalar)
"""overload M / a, for some reason this has no effect"""
return self * (1 / scalar)

def __rmul__(self, scalar, inplace=False):
"""Overload a * M"""
Expand Down

0 comments on commit 5fa534a

Please sign in to comment.