Skip to content


re-adding missing functions
Browse files Browse the repository at this point in the history
  • Loading branch information
egelfan2 committed Dec 10, 2024
1 parent 5ae79a2 commit 2cf4568
Showing 1 changed file with 349 additions and 3 deletions.
352 changes: 349 additions & 3 deletions spatial_compare/
Original file line number Diff line number Diff line change
Expand Up @@ -574,13 +574,11 @@ def collect_mutual_match_and_doublets(
save: whether to save the intermediate and final results. Bool, default true
nn_dist: distance cutoff for identifying likely same cell between segmentations. float, default 2.5 (um)
min_transcripts: minimum number of transcripts needed to define a cell too low quality to be considered for mapping
savepath: path to which you'd like to save your results. Required if using anndata objects and save = True, otherwise defaults to seg_b_path
savepath: path to which you'd like to save your results.
seg_comp_df: dataframe with unique index describing cell spatial locations (x and y), segmentation identifier, low quality cell identifier, mutual matches, and putative doublets.

if not savepath:
savepath = seg_b_path
# grab base comparison df
seg_comp_df = get_segmentation_data(
Expand Down Expand Up @@ -1105,3 +1103,351 @@ def get_column_ordering(df, ordered_rows):
output = []
[output.extend(empty_columns[k]) for k in empty_columns]
return output

def create_seg_comp_df(barcode, seg_name, base_path, min_transcripts):
    """
    Gather one segmentation's results into a descriptive dataframe.

    INPUTS
        barcode: unique identifier, used to locate the segmentation folder under base_path. String.
        seg_name: descriptor of the segmentation, i.e. algorithm name. String.
        base_path: path to the folder holding per-barcode segmentation results.
        min_transcripts: minimum number of transcripts a cell must contain; cells below
            this count are flagged as too low quality to be considered for mapping.
    OUTPUTS
        seg_df: dataframe with cell xy locations ("center_x", "center_y"), the source
            segmentation name ("source") and a boolean "low_quality_cells" flag.
            The index is "<seg_name>_<cell id>".
    """
    seg_path = Path(base_path).joinpath(barcode)

    # A folder counts as cellpose output when any csv in it has "cellpose" in its stem.
    has_cellpose_csv = any("cellpose" in p.stem for p in seg_path.glob("*.csv"))
    if has_cellpose_csv:
        cxg = pd.read_csv(seg_path / "cellpose-cell-by-gene.csv", index_col=0)
        metadata = pd.read_csv(seg_path / "cellpose_metadata.csv", index_col=0)
    else:
        cxg = pd.read_csv(seg_path / "cell_by_gene.csv", index_col=0)
        metadata = ad.read_h5ad(seg_path / "metadata.h5ad").obs

    # Cell ids are compared as strings: seg_df's index is cast to str below, so the
    # high-quality ids must be cast as well or integer ids read from csv never match.
    passing_positions = np.where(transcripts_per_cell(cxg) >= min_transcripts)[0]
    high_quality_cells = cxg.index[passing_positions].astype(str)

    # assemble seg df with xy cols, seg name, and low quality flag
    seg_df = metadata[["center_x", "center_y"]].copy()
    seg_df.index = seg_df.index.astype(str)
    seg_df["source"] = seg_name
    seg_df["low_quality_cells"] = ~seg_df.index.isin(high_quality_cells)

    # Prefix ids with the segmentation name so they stay unique once frames are stacked.
    seg_df.index = seg_name + "_" + seg_df.index
    return seg_df

def get_segmentation_data(
# NOTE(review): the parameter list is not visible in this view. The body below reads bc,
# anndata_a, anndata_b, seg_name_a, seg_name_b, save, savepath, reuse_saved and
# min_transcripts; confirm their order and defaults against the repository.
Loads segmentation data from anndata object
bc: unique section identifier, string
seg_name_a, seg_name_b: names of segmentations to be compared, string
anndata_path_a, anndata_path_b: path to segmented anndata objects, string
save: whether to save the intermediary dataframe (useful if running functions individually), bool
savepath: path to which data should be saved. Defaults to seg_b_path if not specified.
min_transcripts: minimum number of transcripts needed to define a cell too low quality to be considered for mapping
seg_comp_df: dataframe describing cell spatial locations (x and y), segmentation identifier, low quality cell identifier, with unique index
# NOTE(review): the description names anndata_path_a / anndata_path_b, but the code below
# uses anndata_a / anndata_b directly as objects exposing .obs, .X and .obsm.
# The output/cache file name is built by plain string concatenation, so no path
# separator is inserted between savepath and bc.
savepath = (
savepath + bc + "_seg_comp_df_" + seg_name_a + "_and_" + seg_name_b + ".csv"
# Reuse a previously written CSV only when it exists and reuse_saved is truthy.
if Path(savepath).exists() and reuse_saved:
seg_comp_df = pd.read_csv(savepath, index_col=0)
# Otherwise build one frame per segmentation (presumably under an `else:` not visible here).
seg_dfs = []
for i in range(2):
if i == 0:
seg_h5ad = anndata_a
name = seg_name_a
# NOTE(review): the next two assignments presumably sit under an `else:` (i == 1) — confirm.
seg_h5ad = anndata_b
name = seg_name_b
df_cols = ["center_x", "center_y", "source"]
# sum across .X to get transcripts per cell; rows with at least min_transcripts are high quality
if "low_quality_cells" not in seg_h5ad.obs.columns.tolist():
high_quality_cells = [
for idx, x in enumerate(seg_h5ad.X.sum(axis=1))
if x >= min_transcripts
# Flag every cell as low quality first, then clear the flag by position. The iloc
# write targets column -1, so it relies on the column just added being last in .obs.
seg_h5ad.obs["low_quality_cells"] = [True] * len(seg_h5ad.obs)
seg_h5ad.obs.iloc[high_quality_cells, -1] = False
# These writes go to seg_h5ad.obs itself, i.e. the object passed in is modified.
seg_h5ad.obs.loc[:, "source"] = name
# save only necessary columns
# When center_x is absent, coordinates are taken from columns 0 and 1 of obsm["spatial"].
if "center_x" not in seg_h5ad.obs.columns.tolist():
seg_h5ad.obs["center_x"] = seg_h5ad.obsm["spatial"][:, 0]
seg_h5ad.obs["center_y"] = seg_h5ad.obsm["spatial"][:, 1]
seg_df = seg_h5ad.obs[df_cols]
# Stack the per-segmentation frames row-wise and make the index string-typed.
seg_comp_df = pd.concat(seg_dfs, axis=0)
seg_comp_df.index = seg_comp_df.index.astype(str)
# NOTE(review): the statement guarded by `if save:` is not visible in this view.
if save:
return seg_comp_df

def get_mutual_matches(dfa, dfb, nn_dist):
Compare x and y locations using nearest neighbor function of each segmentation result to identify same cells across segmentations
dfa, dfb: subset dataframe for each segmentation
nn_dist: distance cutoff for identifying likely same cell between segmentations. float.
mutual_match: Dictionary of cell IDs to matched cell ids. Keys are stacked source a and source b cell ids, with values being match from other source (so source a cell id keys have source b cell id values, and vice versa)

# get single nearest neighbor for each
# tree_a is built on dfa and queried with dfb's points, so dists_a / inds_a have one
# entry per dfb row; tree_b / dists_b / inds_b are the mirror image (one entry per dfa row).
tree_a = cKDTree(dfa[["center_x", "center_y"]].copy())
dists_a, inds_a = tree_a.query(
dfb[["center_x", "center_y"]].copy(), 1
) # gives me index locs for dfa with neighbor dfb
tree_b = cKDTree(dfb[["center_x", "center_y"]].copy())
dists_b, inds_b = tree_b.query(
dfa[["center_x", "center_y"]].copy(), 1
) # vice versa
# get mutually matching pairs and save (index of seg_comp_df is str)
# NOTE(review): only the index= argument of each DataFrame below is visible in this view;
# the remaining arguments and the closing of the call are missing.
match_to_dfb = pd.DataFrame(
index=pd.Index(dfb.index, name="match_dfb_index"),
# A pair is marked same_cell when its nearest-neighbour distance is within nn_dist.
match_to_dfb["same_cell"] = np.where(dists_a <= nn_dist, True, False)
match_to_dfa = pd.DataFrame(
index=pd.Index(dfb.iloc[inds_b].index.values.tolist(), name="match_dfb_index"),
match_to_dfa["same_cell"] = np.where(dists_b <= nn_dist, True, False)
# No `on=` is given, so pandas merges on the columns the two frames have in common.
mutual_matches = pd.merge(match_to_dfa, match_to_dfb, how="outer")
# NOTE(review): the arguments of this join are not visible in this view.
mutual_matches = mutual_matches.set_index("match_dfa_index").join(
# Keep rows flagged same_cell, drop the flag columns, and (presumably via a to_dict
# call not visible here) obtain a mapping keyed by column name.
mutual_match_dict = (
mutual_matches[mutual_matches["same_cell"] == True]
.drop(["same_cell", "same_cell_match"], axis=1)
# Invert the dfa-id -> dfb-id mapping so dfb ids can be looked up as well.
inv_mutual_match_dict = {
v: k for k, v in mutual_match_dict["match_dfb_index"].items()
# added union operator to dictionaries in 2020 for 3.9+ (pep 584), discussed
# The `|` below therefore requires Python 3.9 or newer.
mutual_matches_stacked = (
mutual_match_dict["match_dfb_index"] | inv_mutual_match_dict
return mutual_matches_stacked

def create_node_df_sankey(
# NOTE(review): the parameter list is not visible in this view. The body reads seg_comp_df,
# barcode, save and savepath; confirm order and defaults against the repository.
Dataframe describing each node to be included in sankey diagram. Additionally contains data helpful for generating required links dataframe.
seg_comp_df: dataframe with unique index describing cell spatial locations (x and y), segmentation identifier, low quality cell identifier, mutual matches, and putative doublets.
barcode: unique identifier for section, used for creating the save name for the dataframe, string.
save: whether to save the final results. Bool, default true
savepath: path to which results should be saved, string
nodes_df: dataframe containing node label, color, and value for sankey diagram creation, with additional level and source keys for ease of link dataframe creation

# Only the first two unique sources receive a colour; the per-group lookups
# color_dict[source] below have no entry for any additional source.
color_dict = {
seg_comp_df.source.unique().tolist()[0]: "red",
seg_comp_df.source.unique().tolist()[1]: "blue",
nodes_df = pd.DataFrame(columns=["Label", "Color", "Level", "Source", "Value"])
unknown_unmatched_cells = {}
# One pass per segmentation; each pass appends that source's nodes for levels 0 to 3.
for source, g in seg_comp_df.groupby("source"):
new_rows = []
# NOTE(review): the arguments passed to agg are not visible in this view. All reads of
# low_q_and_match below are positional (iloc), so they depend on its row/column order.
low_q_and_match = g.groupby("low_quality_cells").agg(
) # gives me low q t/f counts, and matched.
total_cells = low_q_and_match.iloc[:, 0].sum()
nodes_df.loc[len(nodes_df)] = {
"Label": "total <br>" + str(total_cells),
"Color": color_dict[source],
"Level": 0,
"Source": source,
"Value": total_cells,
# low and normal quality cells
# Row 1 is read for low quality and row 0 for normal quality, so both groups must exist.
nodes_df.loc[len(nodes_df)] = {
"Label": "low quality cells <br>" + str(low_q_and_match.iloc[1, 1]),
"Color": color_dict[source],
"Level": 1,
"Source": source,
"Value": low_q_and_match.iloc[1, 1],
nodes_df.loc[len(nodes_df)] = {
"Label": "normal quality cells <br>" + str(low_q_and_match.iloc[0, 1]),
"Color": color_dict[source],
"Level": 1,
"Source": source,
"Value": low_q_and_match.iloc[0, 1],
# matched and unmatched cells
# because row 1 in agg is false only! so false x low (normal) and false x match (unmatch!)
matched_cells = low_q_and_match.iloc[0, 1] - low_q_and_match.iloc[0, 3]
nodes_df.loc[len(nodes_df)] = {
"Label": "matched cells <br>" + str(matched_cells),
"Color": color_dict[source],
"Level": 2,
"Source": source,
"Value": matched_cells,
nodes_df.loc[len(nodes_df)] = {
"Label": "unmatched cells <br>" + str(low_q_and_match.iloc[0, 3]),
"Color": color_dict[source],
"Level": 2,
"Source": source,
"Value": low_q_and_match.iloc[0, 3],
# raise flag if too few cells matched, may indicate scaling issue
# NOTE(review): the check below fires at 10% of the group (0.1 * len(g)), while the
# message text says 1% — one of the two should be corrected.
if matched_cells <= (0.1 * len(g)):
raise ValueError(
"The number of matched cells is less than 1% of total cells present. Please check your inputs for potential scaling issues."

# Columns from position 5 onward are treated as the remaining annotation columns,
# excluding any whose name contains "<source>_unfilt".
rem_col_names = [
x for x in g.columns[5:] if source + "_unfilt" not in x
] # get remaining columns
# pulling this to access unknown unmatched cells easily later
# Columns at positions 3 and 4 of g are addressed by position; the first filter keeps
# rows where column 3 is False, the second keeps rows where column 4 is not null.
high_q_df = g[g[g.columns.tolist()[3]] == False]
unmatched_df = high_q_df[high_q_df[g.columns.tolist()[4]].isna() == False]
rem_col_df = unmatched_df.loc[:, rem_col_names]
rem_col_counts = rem_col_df.isna().value_counts().reset_index()
name = " ".join(rem_col_names[0].split("_")) + "<br>"
# NOTE(review): the trailing selection on the next expression is not visible in this view.
fcounts = rem_col_counts[rem_col_counts[rem_col_names[0]] == False][
nodes_df.loc[len(nodes_df)] = {
"Label": name + str(fcounts),
"Color": color_dict[source],
"Level": 3,
"Source": source,
"Value": fcounts,
# remaining unmatched cells (known or unknown)
if len(rem_col_counts) > 1:
rem_unmatched_cells = rem_col_counts[
rem_col_counts[rem_col_names[0]] == True
# Rows whose first remaining column is null are kept per source for the caller.
unknown_unmatched_cells[source] = rem_col_df[
rem_col_df[rem_col_names[0]].isna() == True
name = " ".join(rem_col_names[-1].split("_"))
nodes_df.loc[len(nodes_df)] = {
"Label": name + str(rem_unmatched_cells),
"Color": color_dict[source],
"Level": 3,
"Source": source,
"Value": rem_unmatched_cells,
if save:
nodes_df.to_csv(savepath + "/sankey_nodes_df" + barcode + ".csv")
# Returns a 2-tuple: the nodes dataframe and the per-source dict of unknown unmatched cells.
return nodes_df, unknown_unmatched_cells

def create_link_df_sankey(
# NOTE(review): the parameter list is not visible in this view. The body reads nodes_df,
# save and savepath; barcode is described below but not used in the visible code.
Generates links dataframe from nodes dataframe, connecting relevant nodes to one another and ensuring distinct segmentations are separated.
nodes_df: dataframe containing information on each node for sankey diagram.
barcode: unique identifier for section, used for creating the save name for the dataframe, string.
save: whether to save the final results. Bool, default true
savepath: path to which results should be saved, string
links_df: dataframe containing source, target, value, and colors for sankey diagram creation

# Both node colours currently map to the same link colour.
link_color_dict = {"red": "rgb(205, 209, 228)", "blue": "rgb(205, 209, 228)"}
links_df = pd.DataFrame(columns=["Source", "Target", "Value", "Link Color"])
# Every non-"total" node becomes the target of one link from a node one level up.
for i, row in nodes_df.iterrows():
if "total" in row.Label:
continue # skip 0 since all comes from zero
target = i
value = row["Value"]
# Candidate parents: nodes one level above with the same Source.
# NOTE(review): the end of this selection expression is not visible in this view.
source = nodes_df[
(nodes_df["Level"] == row["Level"] - 1)
& (nodes_df["Source"] == row["Source"])
# With several candidates, the parent is picked by label substring: below level 3 the
# one without "un" and without "low" in its label, at level 3 the one containing "un".
if len(source) > 1 and row["Level"] != 3:
source = [
for i, x in nodes_df.iloc[source, :].iterrows()
if "un" not in x["Label"]
if "low" not in x["Label"]
elif len(source) > 1 and row["Level"] == 3:
source = [
for i, x in nodes_df.iloc[source, :].iterrows()
if "un" in x["Label"]
source = source[0]
prev_row = nodes_df.iloc[source, :]
link_color = link_color_dict[prev_row["Color"]]
links_df.loc[len(links_df)] = {
"Source": source,
"Target": target,
"Value": value,
"Link Color": link_color,

# special case: matched cells should have same name, new source, new color
matched_cell_rows = nodes_df[
(nodes_df["Label"].str.contains("matched cells"))
& (~nodes_df["Label"].str.contains("un"))
# Only the first matched-cells row supplies the value and level of the shared node.
matched_cell_val = matched_cell_rows["Value"].values.tolist()[0]
# NOTE(review): this appends a row to the nodes_df passed in, so the caller's
# dataframe gains a "Both" node as a side effect.
nodes_df.loc[len(nodes_df)] = {
"Label": "matched cells <br> " + str(matched_cell_val),
"Color": "violet",
"Level": matched_cell_rows["Level"].values.tolist()[0],
"Source": "Both",
"Value": matched_cell_val,
# find matched col for both in nodes_df, then find rows with those IDS in links df and connect them
# NOTE(review): the condition selecting update_idxs is not visible in this view.
update_idxs = links_df[
links_df.loc[update_idxs, "Target"] = [
nodes_df[nodes_df["Source"] == "Both"].index.values
] * len(update_idxs)
# NOTE(review): unlike the nodes file name, this file name does not include the barcode.
if save:
links_df.to_csv(savepath + "/sankey_links_df.csv")
return links_df

0 comments on commit 2cf4568

Please sign in to comment.