diff --git a/q2_gglasso/_summarize/_visualizer.py b/q2_gglasso/_summarize/_visualizer.py index 838fc0e..bbdcf01 100644 --- a/q2_gglasso/_summarize/_visualizer.py +++ b/q2_gglasso/_summarize/_visualizer.py @@ -50,7 +50,7 @@ def _get_colors(df: pd.DataFrame()): return color_list, colors -def _get_labels(solution: zarr.hierarchy.Group): +def _get_labels(solution: zarr.hierarchy.Group, clustered: bool = False, clust_order: list = None): labels_dict = dict() labels_dict_reversed = dict() p = np.array(solution['p']).item() @@ -63,24 +63,32 @@ def _get_labels(solution: zarr.hierarchy.Group): def _get_order(data: pd.DataFrame, method: str = 'average', metric: str = 'euclidean'): + """ + Performs hierarchical clustering on the input DataFrame and returns the cluster order. + + Args: + data (pd.DataFrame): The input DataFrame. + method (str, optional): The clustering method. Defaults to 'average'. + metric (str, optional): The distance metric. Defaults to 'euclidean'. + + Returns: + list: The cluster order as a list of indices. + """ grid = sns.clustermap(data, method=method, metric=metric, robust=True) plt.close() + clust_order = grid.dendrogram_row.reordered_ind - row_order = grid.dendrogram_row.reordered_ind - col_order = grid.dendrogram_col.reordered_ind + return clust_order - return row_order, col_order - -def hierarchical_clustering(data: pd.DataFrame, row_order: list, column_order: list, - n_covariates: int = None): +def hierarchical_clustering(data: pd.DataFrame, clust_order: list, n_covariates: int = None): if n_covariates is None: - re_data = data.iloc[row_order, column_order] + re_data = data.iloc[clust_order, clust_order] else: asv_part = data.iloc[:-n_covariates, :-n_covariates] - re_asv_part = asv_part.iloc[row_order, column_order] - cov_asv_part = data.iloc[:-n_covariates, -n_covariates:].iloc[row_order, :] + re_asv_part = asv_part.iloc[clust_order, clust_order] + cov_asv_part = data.iloc[:-n_covariates, -n_covariates:].iloc[clust_order, :] cov_part = data.iloc[-n_covariates:, -n_covariates:] res = np.block([[re_asv_part.values, cov_asv_part.values], @@ -100,8 +108,11 @@ def _make_heatmap(data: pd.DataFrame(), title: str = None, labels_dict: dict = N shifted_labels_dict = {k + 0.5: v for k, v in labels_dict.items()} shifted_labels_dict_reversed = {k + 0.5: v for k, v in labels_dict_reversed.items()} - df = data.iloc[::-1] # rotate matrix 90 degrees - df = pd.DataFrame(df.stack(), columns=['covariance']).reset_index() + # data.rename(index=labels_dict, inplace=True) + # data.rename(columns=labels_dict, inplace=True) + + # df = data.iloc[::-1] # rotate matrix 90 degrees + df = pd.DataFrame(data.stack(), columns=['covariance']).reset_index() df.columns = ["taxa_y", "taxa_x", "covariance"] df = df.replace({"taxa_x": labels_dict, "taxa_y": labels_dict}) @@ -148,7 +159,6 @@ def _make_heatmap(data: pd.DataFrame(), title: str = None, labels_dict: dict = N return p - # def _make_heatmap(data: pd.DataFrame(), title: str = None, labels_dict: dict = None, # labels_dict_reversed: dict = None, # width: int = 1500, height: int = 1500, label_size: str = "5pt", @@ -242,20 +252,29 @@ def _make_stats(solution: zarr.hierarchy.Group, labels_dict: dict = None): return l1 -def _solution_plot(solution: zarr.hierarchy.Group, width: int, height: int, label_size: str): +def _solution_plot(solution: zarr.hierarchy.Group, width: int, height: int, label_size: str, + clustered: bool = False, n_cov: int = None): tabs = [] - labels_dict, labels_dict_reversed = _get_labels(solution=solution) + labels_dict, labels_dict_reversed = _get_labels(solution=solution, clustered=False) + + sample_covariance = pd.DataFrame(solution['covariance']) + # rotate diagonal + precision = pd.DataFrame(solution['solution/precision_']) + + # if clustered: + # clust_order = _get_order(sample_covariance, method='average', metric='euclidean') + # sample_covariance = hierarchical_clustering(sample_covariance, clust_order=clust_order, + # n_covariates=n_cov) + # precision = hierarchical_clustering(precision, clust_order=clust_order, n_covariates=n_cov) - sample_covariance = pd.DataFrame(solution['covariance']).iloc[::-1] - p1 = _make_heatmap(data=sample_covariance, title="Sample covariance", width=width, - height=height, + p1 = _make_heatmap(data=sample_covariance, title="Sample covariance", + width=width, height=height, label_size=label_size, labels_dict=labels_dict, labels_dict_reversed=labels_dict_reversed) tab1 = Panel(child=row(p1), title="Sample covariance") tabs.append(tab1) # due to inversion we multiply the result by -1 to keep the original color scheme - precision = pd.DataFrame(solution['solution/precision_']).iloc[::-1] p2 = _make_heatmap(data=-1 * precision, labels_dict=labels_dict, labels_dict_reversed=labels_dict_reversed, title="Estimated (negative) inverse covariance", width=width, height=height, @@ -264,7 +283,7 @@ def _solution_plot(solution: zarr.hierarchy.Group, width: int, height: int, labe tabs.append(tab2) try: - low_rank = pd.DataFrame(solution['solution/lowrank_']).iloc[::-1] + low_rank = pd.DataFrame(solution['solution/lowrank_']) p3 = _make_heatmap(data=low_rank, labels_dict=labels_dict, labels_dict_reversed=labels_dict_reversed, title="Low-rank", not_low_rank=False, width=width, height=height,