Skip to content

Commit

Permalink
Merge pull request #27 from AllenInstitute/js/kde
Browse files Browse the repository at this point in the history
feat: use a KDE for smooth scores over tissue area
  • Loading branch information
berl authored Dec 12, 2024
2 parents 1cc4391 + ac77c30 commit edfe680
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 78 deletions.
12 changes: 1 addition & 11 deletions spatial_compare/spatial_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from spatial_compare.utils import grouped_obs_mean


DEFAULT_DATA_NAMES = ["Data 0", "Data 1"]
TARGET_LEGEND_MARKER_SIZE = 20

Expand Down Expand Up @@ -550,7 +549,6 @@ def compare_expression(
"--",
)

print(gene_ratio_dfs.keys())
if len(gene_ratio_dfs.keys()) > 0:
gene_ratio_df = pd.concat(gene_ratio_dfs, axis=1)
else:
Expand All @@ -562,10 +560,7 @@ def compare_expression(
"gene_ratio_dataframe": gene_ratio_df,
}

def plot_detection_ratio(
self, gene_ratio_dataframe, figsize=[15, 15], filtred=True
):

def plot_detection_ratio(self, gene_ratio_dataframe, figsize=[15, 15]):
detection_ratio_plots(
gene_ratio_dataframe,
data_names=self.data_names,
Expand All @@ -574,7 +569,6 @@ def plot_detection_ratio(
)

def spatial_compare(self, **kwargs):

if "category" in kwargs.keys():
self.set_category(kwargs["category"])

Expand Down Expand Up @@ -899,7 +893,6 @@ def filter_and_cluster_twice(
if isinstance(input_ad.X, sparse.csr.csr_matrix):
input_ad.X = input_ad.X.toarray()

print("converted to array ")
if "gene" in input_ad.var.columns:
low_detection_genes = input_ad.var.iloc[
np.nonzero(np.max(input_ad.X, axis=0) <= min_max_counts)
Expand Down Expand Up @@ -929,7 +922,6 @@ def filter_and_cluster_twice(

if run_preprocessing:
# Normalizing to median total counts
print("normalizing total counts")
sc.pp.normalize_total(to_cluster)
# Logarithmize the data
sc.pp.log1p(to_cluster)
Expand All @@ -952,7 +944,6 @@ def filter_and_cluster_twice(

# per cluster, repeat PCA and clustering...
# this duplicates the data! but it's necessary because the scanpy functions would create a copy of the subset anyway(?)
print("2nd round of clustering")
all_subs = []
for cl in to_cluster.obs.leiden_0.unique():
subcopy = to_cluster[to_cluster.obs.leiden_0 == cl, :].copy()
Expand All @@ -967,7 +958,6 @@ def filter_and_cluster_twice(
]

all_subs.append(subcopy)
print("concatenating")
iterative_clusters_ad = ad.concat(all_subs)

return iterative_clusters_ad
Expand Down
248 changes: 181 additions & 67 deletions spatial_compare/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
import seaborn as sns
import scipy as sp


def grouped_obs_mean(adata, group_key, layer=None, gene_symbols=None):
Expand Down Expand Up @@ -48,16 +49,109 @@ def grouped_obs_mean(adata, group_key, layer=None, gene_symbols=None):
return out


def spatial_detection_score_kde(query: pd.DataFrame, grid_out: int = 100):
    """Estimate smooth spatial detection-score maps with Gaussian KDEs.

    For each of the three detection-score columns in ``query``
    ("detection_relative_z_score", "detection_difference",
    "log_10_detection_ratio"), fits a score-weighted and an unweighted
    Gaussian KDE over the cell centroids, evaluates their ratio, and
    rescales it back to the original score range.

    Args:
        query (pd.DataFrame): must contain "x_centroid", "y_centroid" and
            the three detection-score columns listed above.
        grid_out (int, optional): number of sample points per axis of the
            evaluation grid. If 0, the KDEs are evaluated at the cell
            coordinates instead of on a grid. Defaults to 100.

    Returns:
        tuple: ``(estimator, extent, z_score, difference, log_ratio)`` where
            ``estimator`` is a callable returning the (unscaled)
            weighted/unweighted KDE ratio for the z-score column,
            ``extent`` is ``[xmin, xmax, ymin, ymax]``, and the three score
            outputs are ``(grid_out, grid_out)`` arrays (or per-cell vectors
            when ``grid_out == 0``).
    """
    cell_x = query["x_centroid"]
    cell_y = query["y_centroid"]
    cell_coords = np.vstack([cell_x.values, cell_y.values])

    xmin, xmax = cell_x.min(), cell_x.max()
    ymin, ymax = cell_y.min(), cell_y.max()
    extent = [xmin, xmax, ymin, ymax]

    if grid_out == 0:
        positions = cell_coords
    else:
        X, Y = np.mgrid[
            xmin : xmax : complex(0, grid_out), ymin : ymax : complex(0, grid_out)
        ]
        positions = np.vstack([X.ravel(), Y.ravel()])

    estimator = None
    scores = []
    for column in [
        "detection_relative_z_score",
        "detection_difference",
        "log_10_detection_ratio",
    ]:
        # Copy before replacing NaNs: query[column].values is a view into
        # the DataFrame, so writing through it would mutate the caller's data.
        weights = query[column].to_numpy(copy=True)
        weights[np.isnan(weights)] = 0

        # Original score range, used below to rescale the KDE ratio.
        wext = [weights.min(), weights.max()]

        # gaussian_kde requires non-negative weights; shift up by |min|.
        weights = weights + abs(weights.min())

        kde_weighted = sp.stats.gaussian_kde(cell_coords, weights=weights)
        kde_unweighted = sp.stats.gaussian_kde(cell_coords)

        if column == "detection_relative_z_score":
            # Bind the current KDEs as default arguments: a plain closure
            # would late-bind and silently use the KDEs of the *last*
            # column once the loop has finished.
            estimator = (
                lambda positions, kw=kde_weighted, ku=kde_unweighted: kw(
                    positions
                ).T
                / ku(positions).T
            )

        Z = kde_weighted(positions).T / kde_unweighted(positions).T
        # Rescale the density ratio to span the original score range.
        Z = (Z - Z.min()) / (Z.max() - Z.min()) * (wext[1] - wext[0]) + wext[0]
        if grid_out != 0:
            Z = np.reshape(Z, X.shape).T

        scores.append(Z)

    return (estimator, extent, scores[0], scores[1], scores[2])


def spatial_detection_score_binned(
    query: pd.DataFrame,
    n_bins: int = 50,
):
    """Average the detection scores on an ``n_bins`` x ``n_bins`` spatial grid.

    Note: adds an "xy_bucket" column to ``query`` in place.

    Args:
        query (pd.DataFrame): must contain "x_centroid", "y_centroid" and
            the columns "detection_relative_z_score", "detection_difference"
            and "log_10_detection_ratio".
        n_bins (int, optional): number of bins per axis. Defaults to 50.

    Returns:
        tuple: ``(extent, z_score_image, difference_image, ratio_image,
            count_image)`` where ``extent`` is ``[xmin, xmax, ymin, ymax]``
            of the per-bin mean centroids and each image is an
            ``(n_bins, n_bins)`` array indexed ``[y_bin, x_bin]``; bins with
            no cells stay 0.
    """
    bin_labels = list(range(n_bins))
    query["xy_bucket"] = list(
        zip(
            pd.cut(query.x_centroid, n_bins, labels=bin_labels),
            pd.cut(query.y_centroid, n_bins, labels=bin_labels),
        )
    )

    # Group once and reuse: the original recomputed the same groupby six
    # times, repeating the hashing/sorting work for every aggregation.
    gb = query.groupby("xy_bucket")
    binx = gb.x_centroid.mean()
    biny = gb.y_centroid.mean()
    z_score = gb.detection_relative_z_score.mean()
    difference = gb.detection_difference.mean()
    log_ratio = gb.log_10_detection_ratio.mean()
    n_cells = gb.x_centroid.count()

    bin_image_z_score = np.zeros([n_bins, n_bins])
    bin_image_difference = np.zeros([n_bins, n_bins])
    bin_image_ratio = np.zeros([n_bins, n_bins])
    bin_image_counts = np.zeros([n_bins, n_bins])

    extent = [np.min(binx), np.max(binx), np.min(biny), np.max(biny)]
    # Images are indexed [row=y_bin, col=x_bin] for plotting with imshow.
    for coord in binx.index:
        bin_image_z_score[coord[1], coord[0]] = z_score[coord]
        bin_image_difference[coord[1], coord[0]] = difference[coord]
        bin_image_ratio[coord[1], coord[0]] = log_ratio[coord]
        bin_image_counts[coord[1], coord[0]] = n_cells[coord]

    return (
        extent,
        bin_image_z_score,
        bin_image_difference,
        bin_image_ratio,
        bin_image_counts,
    )


def spatial_detection_scores(
reference: pd.DataFrame,
query: pd.DataFrame,
plot_stuff=True,
plot_stuff: bool = True,
query_name: str = "query data",
comparison_column="transcript_counts",
category="supercluster_name",
n_bins=50,
in_place=True,
non_spatial=False,
comparison_column: str = "transcript_counts",
category: str = "supercluster_name",
n_bins: int = 50,
in_place: bool = True,
non_spatial: bool = False,
use_kde: bool = False,
mask: float = 0.0,
):
"""
Calculate and plot spatial detection scores for query data compared to reference data.
Expand All @@ -71,22 +165,25 @@ def spatial_detection_scores(
n_bins (int, optional): The number of bins for spatial grouping. Defaults to 50.
in_place (bool, optional): Whether to modify the query data in place. Defaults to True.
non_spatial (bool, optional): Whether to compare to an ungrouped mean/std. Defaults to False.
use_kde (bool, optional): Whether to use kernel-density estimates instead of taking binned averages. Samples the KDE on a `n_bins` square grid for plotting, unless `n_bins == 0` (in which case it will be sampled at the cell coordinates). Defaults to False.
mask (float, optional): A quantile at which to create binary masks from. Defaults to 0.0 (do not binarize).
Returns:
dict: A dictionary containing the bin image, extent, query data, and reference data (if in_place is False).
"""
# code goes here

if category not in reference.columns or category not in query.columns:
raise ValueError("category " + category + " not in reference and query inputs")

shared_category_values = list(
set(reference[category].unique()) & set(query[category].unique())
)
if (
if in_place and (
len(shared_category_values) < query[category].unique().shape[0]
or len(shared_category_values) < reference[category].unique().shape[0]
):
print(
"Query and reference datasets had different shapes. Objects will not be modified in place"
)
in_place = False

if in_place:
Expand All @@ -107,55 +204,61 @@ def spatial_detection_scores(
s2["detection_relative_z_score"] = 0.0
s2["detection_difference"] = 0.0
s2["detection_ratio"] = 0.0
s2["log_10_detection_ratio"] = 0.0

for c, gb in s2.groupby(category, observed=True):
if c not in shared_category_values:
continue

s2.loc[s2[category] == c, ["detection_relative_z_score"]] = (
(s2.loc[s2[category] == c, [comparison_column]] - means[c]) / stds[c]
indices = s2[category] == c

s2.loc[indices, ["detection_relative_z_score"]] = (
(s2.loc[indices, [comparison_column]] - means[c]) / stds[c]
).values
s2.loc[s2[category] == c, ["detection_difference"]] = (
s2.loc[s2[category] == c, [comparison_column]] - means[c]
s2.loc[indices, ["detection_difference"]] = (
s2.loc[indices, [comparison_column]] - means[c]
).values
s2.loc[s2[category] == c, ["log_10_detection_ratio"]] = np.log10(
(s2.loc[s2[category] == c, [comparison_column]] / means[c]).values
s2.loc[indices, ["detection_ratio"]] = (
s2.loc[indices, [comparison_column]] / means[c]
).values
s2.loc[indices, ["log_10_detection_ratio"]] = np.log10(
s2.loc[indices, ["detection_ratio"]].values
)

s2["xy_bucket"] = list(
zip(
pd.cut(s2.x_centroid, n_bins, labels=list(range(n_bins))),
pd.cut(s2.y_centroid, n_bins, labels=list(range(n_bins))),
if not use_kde:
(
extent,
bin_image_z_score,
bin_image_difference,
bin_image_ratio,
bin_image_counts,
) = spatial_detection_score_binned(s2, n_bins)
else:
(
estimator,
extent,
bin_image_z_score,
bin_image_difference,
bin_image_ratio,
) = spatial_detection_score_kde(s2, n_bins)
# FIXME: not computing this
bin_image_counts = np.zeros(bin_image_ratio.shape)

if mask != 0.0:
bin_image_z_score = bin_image_z_score > np.quantile(bin_image_z_score, mask)
bin_image_difference = bin_image_difference > np.quantile(
bin_image_difference, mask
)
)

binx = s2.groupby("xy_bucket").x_centroid.mean()
biny = s2.groupby("xy_bucket").y_centroid.mean()

z_score = s2.groupby("xy_bucket").detection_relative_z_score.mean()
difference = s2.groupby("xy_bucket").detection_difference.mean()
log_ratio = s2.groupby("xy_bucket").log_10_detection_ratio.mean()
n_cells = s2.groupby("xy_bucket").x_centroid.count()

bin_image_z_score = np.zeros([n_bins, n_bins])
bin_image_difference = np.zeros([n_bins, n_bins])
bin_image_ratio = np.zeros([n_bins, n_bins])
bin_image_counts = np.zeros([n_bins, n_bins])

extent = [np.min(binx), np.max(binx), np.min(biny), np.max(biny)]
for coord in binx.index:
bin_image_z_score[coord[1], coord[0]] = z_score[coord]
bin_image_difference[coord[1], coord[0]] = difference[coord]
bin_image_ratio[coord[1], coord[0]] = log_ratio[coord]
bin_image_counts[coord[1], coord[0]] = n_cells[coord]
bin_image_ratio = bin_image_ratio > np.quantile(bin_image_ratio, mask)
bin_image_counts = bin_image_counts > np.quantile(bin_image_counts, mask)

if plot_stuff:
if non_spatial:
title_string = "Non-spatial Detection Scores"
else:
title_string = "Spatial Detection Scores"
min_maxes = {
"detection z-score": [bin_image_z_score, [-1, 1]],
"detection z-score": [bin_image_z_score, [-2, 2]],
"total counts difference": [bin_image_difference, [-100, 100]],
"log10(detection ratio)": [bin_image_ratio, [-1, 1]],
}
Expand All @@ -174,38 +277,49 @@ def spatial_detection_scores(
)
for ii, plot_name in enumerate(min_maxes.keys()):
ax = axs[ii]
pcm = ax.imshow(
min_maxes[plot_name][0],
extent=extent,
cmap="coolwarm_r",
vmin=min_maxes[plot_name][1][0],
vmax=min_maxes[plot_name][1][1],
)

if mask != 0.0:
cmap = "Greys"
else:
cmap = "coolwarm_r"

if n_bins != 0:
pcm = ax.imshow(
min_maxes[plot_name][0],
extent=extent,
cmap=cmap,
vmin=min_maxes[plot_name][1][0],
vmax=min_maxes[plot_name][1][1],
)
else:
pcm = ax.scatter(
s2.x_centroid.values,
-s2.y_centroid.values,
c=min_maxes[plot_name][0],
cmap=cmap,
)
fig.colorbar(pcm, ax=ax, shrink=0.7)
ax.set_title(query_name + "\n" + plot_name)

if in_place:
ret = dict(
z_score_image=bin_image_z_score,
difference_image=bin_image_difference,
ratio_image=bin_image_ratio,
extent=extent,
count_image=bin_image_counts,
query=True,
reference=True,
)

return dict(
z_score_image=bin_image_z_score,
difference_image=bin_image_difference,
ratio_image=bin_image_ratio,
extent=extent,
count_image=bin_image_counts,
query=True,
reference=True,
)
if use_kde:
ret["z_score_estimator"] = estimator

if in_place:
return ret
else:
return dict(
z_score_image=bin_image_z_score,
difference_image=bin_image_difference,
ratio_image=bin_image_ratio,
extent=extent,
count_image=bin_image_counts,
query=s2,
reference=s1,
)
ret["query"] = s2
ret["reference"] = s1
return ret


def summarize_and_plot(
Expand Down

0 comments on commit edfe680

Please sign in to comment.