Skip to content

Commit

Permalink
fixed matrix rotation for heatmap
Browse files Browse the repository at this point in the history
  • Loading branch information
Vlasovets committed Dec 15, 2023
1 parent 5642f76 commit cc08e2e
Showing 1 changed file with 39 additions and 20 deletions.
59 changes: 39 additions & 20 deletions q2_gglasso/_summarize/_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _get_colors(df: pd.DataFrame()):
return color_list, colors


def _get_labels(solution: zarr.hierarchy.Group):
def _get_labels(solution: zarr.hierarchy.Group, clustered: bool = False, clust_order: list = None):
labels_dict = dict()
labels_dict_reversed = dict()
p = np.array(solution['p']).item()
Expand All @@ -63,24 +63,32 @@ def _get_labels(solution: zarr.hierarchy.Group):


def _get_order(data: pd.DataFrame, method: str = 'average', metric: str = 'euclidean'):
"""
Performs hierarchical clustering on the input DataFrame and returns the cluster order.
Args:
data (pd.DataFrame): The input DataFrame.
method (str, optional): The clustering method. Defaults to 'average'.
metric (str, optional): The distance metric. Defaults to 'euclidean'.
Returns:
list: The cluster order as a list of indices.
"""
grid = sns.clustermap(data, method=method, metric=metric, robust=True)
plt.close()
clust_order = grid.dendrogram_row.reordered_ind

row_order = grid.dendrogram_row.reordered_ind
col_order = grid.dendrogram_col.reordered_ind
return clust_order

return row_order, col_order


def hierarchical_clustering(data: pd.DataFrame, row_order: list, column_order: list,
n_covariates: int = None):
def hierarchical_clustering(data: pd.DataFrame, clust_order: list, n_covariates: int = None):
if n_covariates is None:
re_data = data.iloc[row_order, column_order]
re_data = data.iloc[clust_order, clust_order]

else:
asv_part = data.iloc[:-n_covariates, :-n_covariates]
re_asv_part = asv_part.iloc[row_order, column_order]
cov_asv_part = data.iloc[:-n_covariates, -n_covariates:].iloc[row_order, :]
re_asv_part = asv_part.iloc[clust_order, clust_order]
cov_asv_part = data.iloc[:-n_covariates, -n_covariates:].iloc[clust_order, :]
cov_part = data.iloc[-n_covariates:, -n_covariates:]

res = np.block([[re_asv_part.values, cov_asv_part.values],
Expand All @@ -100,8 +108,11 @@ def _make_heatmap(data: pd.DataFrame(), title: str = None, labels_dict: dict = N
shifted_labels_dict = {k + 0.5: v for k, v in labels_dict.items()}
shifted_labels_dict_reversed = {k + 0.5: v for k, v in labels_dict_reversed.items()}

df = data.iloc[::-1] # rotate matrix 90 degrees
df = pd.DataFrame(df.stack(), columns=['covariance']).reset_index()
# data.rename(index=labels_dict, inplace=True)
# data.rename(columns=labels_dict, inplace=True)

# df = data.iloc[::-1] # rotate matrix 90 degrees
df = pd.DataFrame(data.stack(), columns=['covariance']).reset_index()
df.columns = ["taxa_y", "taxa_x", "covariance"]
df = df.replace({"taxa_x": labels_dict, "taxa_y": labels_dict})

Expand Down Expand Up @@ -148,7 +159,6 @@ def _make_heatmap(data: pd.DataFrame(), title: str = None, labels_dict: dict = N

return p


# def _make_heatmap(data: pd.DataFrame(), title: str = None, labels_dict: dict = None,
# labels_dict_reversed: dict = None,
# width: int = 1500, height: int = 1500, label_size: str = "5pt",
Expand Down Expand Up @@ -242,20 +252,29 @@ def _make_stats(solution: zarr.hierarchy.Group, labels_dict: dict = None):
return l1


def _solution_plot(solution: zarr.hierarchy.Group, width: int, height: int, label_size: str):
def _solution_plot(solution: zarr.hierarchy.Group, width: int, height: int, label_size: str,
clustered: bool = False, n_cov: int = None):
tabs = []
labels_dict, labels_dict_reversed = _get_labels(solution=solution)
labels_dict, labels_dict_reversed = _get_labels(solution=solution, clustered=False)

sample_covariance = pd.DataFrame(solution['covariance'])
# rotate diagonal
precision = pd.DataFrame(solution['solution/precision_'])

# if clustered:
# clust_order = _get_order(sample_covariance, method='average', metric='euclidean')
# sample_covariance = hierarchical_clustering(sample_covariance, clust_order=clust_order,
# n_covariates=n_cov)
# precision = hierarchical_clustering(precision, clust_order=clust_order, n_covariates=n_cov)

sample_covariance = pd.DataFrame(solution['covariance']).iloc[::-1]
p1 = _make_heatmap(data=sample_covariance, title="Sample covariance", width=width,
height=height,
p1 = _make_heatmap(data=sample_covariance, title="Sample covariance",
width=width, height=height,
label_size=label_size, labels_dict=labels_dict,
labels_dict_reversed=labels_dict_reversed)
tab1 = Panel(child=row(p1), title="Sample covariance")
tabs.append(tab1)

# due to inversion we multiply the result by -1 to keep the original color scheme
precision = pd.DataFrame(solution['solution/precision_']).iloc[::-1]
p2 = _make_heatmap(data=-1 * precision, labels_dict=labels_dict,
labels_dict_reversed=labels_dict_reversed,
title="Estimated (negative) inverse covariance", width=width, height=height,
Expand All @@ -264,7 +283,7 @@ def _solution_plot(solution: zarr.hierarchy.Group, width: int, height: int, labe
tabs.append(tab2)

try:
low_rank = pd.DataFrame(solution['solution/lowrank_']).iloc[::-1]
low_rank = pd.DataFrame(solution['solution/lowrank_'])
p3 = _make_heatmap(data=low_rank, labels_dict=labels_dict,
labels_dict_reversed=labels_dict_reversed,
title="Low-rank", not_low_rank=False, width=width, height=height,
Expand Down

0 comments on commit cc08e2e

Please sign in to comment.