Skip to content

Commit

Permalink
Add log option to plot
Browse files Browse the repository at this point in the history
  • Loading branch information
Frederike Duembgen committed Jan 15, 2024
1 parent abdd3cb commit 9589b87
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions poly_matrix/poly_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,18 @@ def __getitem__(self, key):
def get_size(self):
return max(self.last_var_i_index, self.last_var_j_index)

def rename(self, from_labels, to_label):
"""Rename all keys in from_labels to to_label, summing up the respective values."""
for from_label in from_labels:
key_j_list = list(self.matrix[from_label].keys())
for key_j in key_j_list:
if key_j in from_labels:
key_j_to = to_label
else:
key_j_to = key_j
self[to_label, key_j_to] += self[from_label, key_j]
self.drop(from_labels)

def add_variable_i(self, key, size):
self.variable_dict_i[key] = size
self.last_var_i_index += size
Expand Down Expand Up @@ -531,8 +543,8 @@ def get_matrix_sparse(self, variables=None, output_type="coo", verbose=False):
rows, cols = np.nonzero(values)
i_list = np.append(i_list, rows + indices_i[key_i])
j_list = np.append(j_list, cols + indices_j[key_j])
data_list = np.append(data_list, values[rows,cols])
data_list = np.append(data_list, values[rows, cols])

if verbose:
print(f"Filling took {time.time() - t1:.2}s.")

Expand Down Expand Up @@ -636,6 +648,7 @@ def _plot_matrix(
variables_i=None,
variables_j=None,
reduced_ticks=False,
log=False,
**kwargs,
):
if type(variables_i) is dict:
Expand Down Expand Up @@ -663,10 +676,15 @@ def _plot_matrix(
raise ValueError("untreated case!")

mat = self.get_matrix(variables=(variables_i, variables_j))

if plot_type == "sparse":
im = ax.spy(mat, **kwargs)
elif plot_type == "dense":
im = ax.matshow(mat.toarray(), **kwargs)
if log:
mat_plot = np.log10(np.abs(mat.toarray()))
else:
mat_plot = mat.toarray()
im = ax.matshow(mat_plot, **kwargs)
else:
raise ValueError(plot_type)

Expand Down Expand Up @@ -708,6 +726,7 @@ def matshow(
variables_i=None,
variables_j=None,
ax=None,
log=False,
**kwargs,
):
if ax is None:
Expand All @@ -720,6 +739,7 @@ def matshow(
variables=variables,
variables_i=variables_i,
variables_j=variables_j,
log=log,
**kwargs,
)
return fig, ax, im
Expand Down

0 comments on commit 9589b87

Please sign in to comment.