diff --git a/poly_matrix/poly_matrix.py b/poly_matrix/poly_matrix.py index fee55a2..e17bfa3 100644 --- a/poly_matrix/poly_matrix.py +++ b/poly_matrix/poly_matrix.py @@ -184,6 +184,18 @@ def __getitem__(self, key): def get_size(self): return max(self.last_var_i_index, self.last_var_j_index) + def rename(self, from_labels, to_label): + """Rename all keys in from_labels to to_label, summing up the respective values.""" + for from_label in from_labels: + key_j_list = list(self.matrix[from_label].keys()) + for key_j in key_j_list: + if key_j in from_labels: + key_j_to = to_label + else: + key_j_to = key_j + self[to_label, key_j_to] += self[from_label, key_j] + self.drop(from_labels) + def add_variable_i(self, key, size): self.variable_dict_i[key] = size self.last_var_i_index += size @@ -531,8 +543,8 @@ def get_matrix_sparse(self, variables=None, output_type="coo", verbose=False): rows, cols = np.nonzero(values) i_list = np.append(i_list, rows + indices_i[key_i]) j_list = np.append(j_list, cols + indices_j[key_j]) - data_list = np.append(data_list, values[rows,cols]) - + data_list = np.append(data_list, values[rows, cols]) + if verbose: print(f"Filling took {time.time() - t1:.2}s.") @@ -636,6 +648,7 @@ def _plot_matrix( variables_i=None, variables_j=None, reduced_ticks=False, + log=False, **kwargs, ): if type(variables_i) is dict: @@ -663,10 +676,15 @@ def _plot_matrix( raise ValueError("untreated case!") mat = self.get_matrix(variables=(variables_i, variables_j)) + if plot_type == "sparse": im = ax.spy(mat, **kwargs) elif plot_type == "dense": - im = ax.matshow(mat.toarray(), **kwargs) + if log: + mat_plot = np.log10(np.abs(mat.toarray())) + else: + mat_plot = mat.toarray() + im = ax.matshow(mat_plot, **kwargs) else: raise ValueError(plot_type) @@ -708,6 +726,7 @@ def matshow( variables_i=None, variables_j=None, ax=None, + log=False, **kwargs, ): if ax is None: @@ -720,6 +739,7 @@ def matshow( variables=variables, variables_i=variables_i, variables_j=variables_j, + log=log, **kwargs, ) return fig, ax, im