Skip to content

Commit

Permalink
Merge pull request #35 from FelipePCarcanholo/main
Browse files Browse the repository at this point in the history
Add automatic spot size for spatial plot, select genes for visualization and other visualization updates
  • Loading branch information
berl authored Dec 12, 2024
2 parents 4062996 + cb47609 commit 1cc4391
Showing 1 changed file with 88 additions and 21 deletions.
109 changes: 88 additions & 21 deletions spatial_compare/spatial_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


DEFAULT_DATA_NAMES = ["Data 0", "Data 1"]
TARGET_LEGEND_MARKER_SIZE = 20


class SpatialCompare:
Expand Down Expand Up @@ -58,15 +59,18 @@ class SpatialCompare:
-------
set_category(category)
Set the category to compare.
spatial_plot(plot_legend=True, min_cells_to_plot=10, decimate_for_spatial_plot=1, figsize=[20,10], category_values=[])
spatial_plot(plot_legend=True, min_cells_to_plot=10, decimate_for_spatial_plot=1, figsize=[20,10], category_values=[], dot_size=3)
Plot the spatial data for the two datasets.
de_novo_cluster(plot_stuff=False, correspondence_level="leiden_1",rerun_preprocessing=False)
Perform de novo clustering on the two datasets.
find_matched_groups(n_top_groups=100, n_shared_groups=30, min_n_cells=100, category_values=[], exclude_group_string="zzzzzzzzzzzzzzz", plot_stuff=False, figsize=[10,10])
Find matched groups between the two datasets.
compare_expression(category_values=[], plot_stuff=False, min_mean_expression=.2, min_genes_to_compare=5, min_cells=10)
compare_expression(category_values=[], plot_stuff=False, min_mean_expression=.2, min_genes_to_compare=5, min_cells=10, ntop_genes=20)
Compare gene expression between the two datasets.
run_and_plot(category_values = d1d2_cells, min_mean_expression=.2, ntop_genes=20, filtred=True, dot_size=)
Run all the plots, can select the genes to appear the label (ntop_genes), choose to filter 25 bottom, middle and top genes in the boxplot (filtred=True). Can choose the size of dots of spatial plot (dot_size=(3*18231)/(self.ad_0.n_obs)).
"""

def __init__(
Expand Down Expand Up @@ -178,6 +182,7 @@ def spatial_plot(
decimate_for_spatial_plot=1,
figsize=[20, 10],
category_values=[],
dot_size=None, # Add a parameter for dot size
):

plt.figure(figsize=figsize)
Expand All @@ -187,8 +192,16 @@ def spatial_plot(
if len(category_values) == 0:
category_values = all_category_values

if dot_size is None:
ad0_dot_size = (3 * 18231) / (self.ad_0.n_obs)
ad1_dot_size = (3 * 18231) / (self.ad_1.n_obs)
else:
ad0_dot_size = dot_size
ad1_dot_size = dot_size

for c in category_values:
plt.subplot(1, 2, 1)

plt.title(self.data_names[0])
if np.sum(self.ad_0.obs[self.category] == c) > min_cells_to_plot:
label = c + ": " + str(np.sum(self.ad_0.obs[self.category] == c))
Expand All @@ -204,11 +217,12 @@ def spatial_plot(
],
".",
label=label,
markersize=0.5,
markersize=ad0_dot_size, # Use the dot_size parameter
)
plt.axis("equal")
if plot_legend:
plt.legend(markerscale=5)
markerscale = TARGET_LEGEND_MARKER_SIZE / ad0_dot_size
plt.legend(markerscale=markerscale)
plt.subplot(1, 2, 2)
plt.title(self.data_names[1])
if np.sum(self.ad_1.obs[self.category] == c) > min_cells_to_plot:
Expand All @@ -224,11 +238,12 @@ def spatial_plot(
],
".",
label=label,
markersize=0.5,
markersize=ad1_dot_size, # Use the dot_size parameter
)
plt.axis("equal")
if plot_legend:
plt.legend(markerscale=5)
markerscale = TARGET_LEGEND_MARKER_SIZE / ad1_dot_size
plt.legend(markerscale=markerscale)

def de_novo_cluster(
self, plot_stuff=False, correspondence_level="leiden_1", run_preprocessing=False
Expand Down Expand Up @@ -377,15 +392,17 @@ def compare_expression(
min_mean_expression=0.2,
min_genes_to_compare=5,
min_cells=10,
ntop_genes=20,
):
# group cells
# Group cells
if len(category_values) == 0:
raise ValueError(
"please supply a list of values for the category " + self.category
)

category_records = []
gene_ratio_dfs = {}

for category_value in category_values:
group_mask_0 = self.ad_0.obs[self.category] == category_value
group_mask_1 = self.ad_1.obs[self.category] == category_value
Expand Down Expand Up @@ -417,17 +434,22 @@ def compare_expression(
axis=0,
)
).flatten()

# Filter genes above minimum mean expression
means_0_gt_min = np.nonzero(means_0 > min_mean_expression)[0]
means_1_gt_min = np.nonzero(means_1 > min_mean_expression)[0]

above_means0 = self.ad_0.var[
self.ad_0.var.index.isin(self.shared_genes)
].iloc[means_0_gt_min]
above_means1 = self.ad_1.var[
self.ad_1.var.index.isin(self.shared_genes)
].iloc[means_1_gt_min]

shared_above_mean = [
g for g in above_means1.index if g in above_means0.index
]

if len(shared_above_mean) < min_genes_to_compare:
print(
self.category
Expand All @@ -440,13 +462,24 @@ def compare_expression(
)
continue

# Calculate means again after filtering
means_0 = np.array(
np.mean(self.ad_0[group_mask_0, shared_above_mean].X, axis=0)
).flatten()
means_1 = np.array(
np.mean(self.ad_1[group_mask_1, shared_above_mean].X, axis=0)
).flatten()

# Calculate average counts for selecting top genes
average_counts = (means_0 + means_1) / 2

# Get indices of the top 20 genes based on average counts for this subclass
top_indices = np.argsort(average_counts)[
-ntop_genes:
] # Get indices of top 10 genes

shared_genes = shared_above_mean

p_coef = np.polynomial.Polynomial.fit(means_0, means_1, 1).convert().coef
category_records.append(
{
Expand Down Expand Up @@ -478,6 +511,7 @@ def compare_expression(
+ " mean ratio: "
+ str(category_records[-1]["mean_ratio"])[:4]
)

low_expression = np.logical_and(means_0 < 1.0, means_1 < 1.0)
plt.loglog(
means_0[low_expression],
Expand All @@ -494,39 +528,49 @@ def compare_expression(
plt.xlabel(self.data_names[0] + ", N = " + str(np.sum(group_mask_0)))
plt.ylabel(self.data_names[1] + ", N = " + str(np.sum(group_mask_1)))

for g in shared_genes:
if (
means_0[np.nonzero(np.array(shared_genes) == g)] == 0
or means_1[np.nonzero(np.array(shared_genes) == g)] == 0
) or low_expression[np.array(shared_genes) == g]:
# Add labels only for the top 20 genes based on average counts for this subclass
for idx in top_indices:
g = shared_genes[idx] if idx < len(shared_genes) else None

if g is None or (
means_0[idx] == 0 or means_1[idx] == 0 or low_expression[idx]
):
continue

plt.text(
means_0[np.nonzero(np.array(shared_genes) == g)],
means_1[np.nonzero(np.array(shared_genes) == g)],
means_0[idx],
means_1[idx],
g,
fontsize=10,
)

plt.plot(
[np.min(means_0), np.max(means_0)],
[np.min(means_0), np.max(means_0)],
"--",
)

print(gene_ratio_dfs.keys())
if len(gene_ratio_dfs.keys()) > 0:

gene_ratio_df = pd.concat(gene_ratio_dfs, axis=1)
else:
gene_ratio_df = None

return {
"data_names": self.data_names,
"category_results": pd.DataFrame.from_records(category_records),
"gene_ratio_dataframe": gene_ratio_df,
}

def plot_detection_ratio(self, gene_ratio_dataframe, figsize=[15, 15]):
def plot_detection_ratio(
self, gene_ratio_dataframe, figsize=[15, 15], filtred=True
):

detection_ratio_plots(
gene_ratio_dataframe, data_names=self.data_names, figsize=figsize
gene_ratio_dataframe,
data_names=self.data_names,
figsize=figsize,
filtred=filtred,
)

def spatial_compare(self, **kwargs):
Expand Down Expand Up @@ -565,12 +609,16 @@ def spatial_compare(self, **kwargs):
def run_and_plot(self, **kwargs):
if "category" in kwargs.keys():
self.set_category(kwargs["category"])
dot_size = kwargs.get("dot_size", None)
ntop_genes = kwargs.get("ntop_genes", 20)
filtred = kwargs.get("filtred", True)

self.spatial_plot()
self.spatial_plot(dot_size=dot_size)
self.spatial_compare_results = self.spatial_compare(plot_stuff=True, **kwargs)
self.plot_detection_ratio(
self.spatial_compare_results["expression_results"]["gene_ratio_dataframe"],
figsize=[30, 20],
filtred=filtred,
)
return True

Expand Down Expand Up @@ -926,25 +974,44 @@ def filter_and_cluster_twice(


def detection_ratio_plots(
gene_ratio_df, data_names=DEFAULT_DATA_NAMES, figsize=[15, 15]
gene_ratio_df,
data_names=DEFAULT_DATA_NAMES,
figsize=[15, 15],
filtred=True,
):

sorted_genes = [
str(s) for s in gene_ratio_df.mean(axis=1).sort_values().index.values
]
# Select top 25, bottom 25 and middle 25
top_25 = sorted_genes[-25:] # Top 25 highest
bottom_25 = sorted_genes[:25] # Bottom 25 lowest
middle_index = len(sorted_genes) // 2
middle_25 = sorted_genes[middle_index - 12 : middle_index + 13] # Middle 25

# Combine selected ratios for plotting
selected_ratios = bottom_25 + middle_25 + top_25
if filtred:
genes_boxplot = selected_ratios
else:
genes_boxplot = sorted_genes

plt.figure(figsize=figsize)
plt.subplot(3, 1, 1)
p = sns.boxplot(
gene_ratio_df.loc[sorted_genes, :].T,
gene_ratio_df.loc[genes_boxplot, :].T,
)
p.set_yscale("log")
p.set_xlabel("gene", fontsize=20)
p.set_ylabel(
"detection ratio\n" + data_names[1] + " / " + data_names[0], fontsize=20
)
ax = plt.gca()
ax.tick_params(axis="x", labelrotation=45, labelsize=10)
if filtred:
ax.tick_params(axis="x", labelrotation=45, labelsize=18)
else:
ax.tick_params(axis="x", labelrotation=45, labelsize=10)

ax.tick_params(axis="y", labelsize=20, which="major")
ax.tick_params(axis="y", labelsize=10, which="minor")

Expand Down

0 comments on commit 1cc4391

Please sign in to comment.