Commit

re-adding missing functions

egelfan2 committed Dec 10, 2024
1 parent 5ae79a2 commit 2cf4568
Showing 1 changed file with 349 additions and 3 deletions.
352 changes: 349 additions & 3 deletions spatial_compare/spatial_compare.py
@@ -574,13 +574,11 @@ def collect_mutual_match_and_doublets(
save: whether to save the intermediate and final results. Bool, default true
nn_dist: distance cutoff for identifying likely same cell between segmentations. float, default 2.5 (um)
min_transcripts: minimum number of transcripts a cell must contain; cells with fewer are flagged as low quality and excluded from mapping
savepath: path to which you'd like to save your results. Required if using anndata objects and save = True, otherwise defaults to seg_b_path
savepath: path to which you'd like to save your results.
OUTPUTS:
seg_comp_df: dataframe with unique index describing cell spatial locations (x and y), segmentation identifier, low quality cell identifier, mutual matches, and putative doublets.
"""

if not savepath:
savepath = seg_b_path
# grab base comparison df
seg_comp_df = get_segmentation_data(
bc,
@@ -1105,3 +1103,351 @@ def get_column_ordering(df, ordered_rows):
output = []
[output.extend(empty_columns[k]) for k in empty_columns]
return output


def create_seg_comp_df(barcode, seg_name, base_path, min_transcripts):
"""
Gathers related segmentation results to create a descriptive dataframe
INPUTS
barcode: unique identifier, used to locate the section-specific segmentation path. String.
seg_name: descriptor of the segmentation, e.g. the algorithm name. String.
base_path: path to the segmentation results. String.
min_transcripts: minimum number of transcripts a cell must contain; cells with fewer are flagged as low quality and excluded from mapping
OUTPUTS
seg_df: dataframe with segmentation cell xy locations, low quality cell identifier, and source segmentation ID. The index is the segmentation name prepended to the cell IDs.
"""
seg_path = Path(base_path).joinpath(barcode)
cell_check = [x for x in seg_path.glob("*.csv") if "cellpose" in x.stem]
if cell_check:
cxg = pd.read_table(
str(seg_path) + "/cellpose-cell-by-gene.csv", index_col=0, sep=","
)
metadata = pd.read_table(
str(seg_path) + "/cellpose_metadata.csv", index_col=0, sep=","
)
else:
cxg = pd.read_table(str(seg_path) + "/cell_by_gene.csv", index_col=0, sep=",")
meta = ad.read_h5ad(str(seg_path) + "/metadata.h5ad")
metadata = meta.obs

# assemble seg df with filt cells col, xy cols, and seg name
high_quality_cells = (
cxg.index[np.where((transcripts_per_cell(cxg) >= min_transcripts))[0]]
.astype(str)
.values.tolist()
)
seg_df = metadata[["center_x", "center_y"]].copy()
seg_df.index = seg_df.index.astype(str)
seg_df.loc[:, "source"] = seg_name
seg_df.loc[:, "low_quality_cells"] = np.where(
seg_df.index.isin(high_quality_cells), False, True
)
seg_df.index = seg_name + "_" + seg_df.index
return seg_df
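
For illustration, a minimal usage sketch of the loader above; the barcode, path, and threshold below are hypothetical placeholders:

# hypothetical inputs: create_seg_comp_df expects to find
# <base_path>/<barcode>/cellpose-cell-by-gene.csv + cellpose_metadata.csv,
# or cell_by_gene.csv + metadata.h5ad, under the barcode directory
seg_df = create_seg_comp_df(
    barcode="1234567890",
    seg_name="cellpose",
    base_path="/path/to/segmentations",
    min_transcripts=40,
)
print(seg_df.columns.tolist())
# ['center_x', 'center_y', 'source', 'low_quality_cells']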


def get_segmentation_data(
bc,
anndata_a,
anndata_b,
seg_name_a,
seg_name_b,
save,
savepath,
reuse_saved,
min_transcripts,
):
"""
Loads segmentation data from anndata objects
Inputs:
bc: unique section identifier, string
seg_name_a, seg_name_b: names of the segmentations to be compared, string
anndata_a, anndata_b: segmented anndata objects
save: whether to save the intermediary dataframe (useful if running functions individually), bool
savepath: path to which data should be saved, string
reuse_saved: whether to load a previously saved dataframe from savepath instead of rebuilding it, bool
min_transcripts: minimum number of transcripts a cell must contain; cells with fewer are flagged as low quality and excluded from mapping
RETURNS
seg_comp_df: dataframe describing cell spatial locations (x and y), segmentation identifier, and low quality cell identifier, with a unique index
"""
savepath = (
savepath + bc + "_seg_comp_df_" + seg_name_a + "_and_" + seg_name_b + ".csv"
)
if Path(savepath).exists() and reuse_saved:
seg_comp_df = pd.read_csv(savepath, index_col=0)
else:
seg_dfs = []
for i in range(2):
if i == 0:
seg_h5ad = anndata_a
name = seg_name_a
else:
seg_h5ad = anndata_b
name = seg_name_b
df_cols = ["center_x", "center_y", "source"]
# sum across .X to get transcripts per cell and flag cells below min_transcripts
if "low_quality_cells" not in seg_h5ad.obs.columns.tolist():
high_quality_cells = [
idx
for idx, x in enumerate(seg_h5ad.X.sum(axis=1))
if x >= min_transcripts
]
seg_h5ad.obs["low_quality_cells"] = [True] * len(seg_h5ad.obs)
seg_h5ad.obs.iloc[high_quality_cells, -1] = False
df_cols.append("low_quality_cells")
seg_h5ad.obs.loc[:, "source"] = name
# save only necessary columns
if "center_x" not in seg_h5ad.obs.columns.tolist():
seg_h5ad.obs["center_x"] = seg_h5ad.obsm["spatial"][:, 0]
seg_h5ad.obs["center_y"] = seg_h5ad.obsm["spatial"][:, 1]
seg_df = seg_h5ad.obs[df_cols]
seg_dfs.append(seg_df)
seg_comp_df = pd.concat(seg_dfs, axis=0)
seg_comp_df.index = seg_comp_df.index.astype(str)
if save:
seg_comp_df.to_csv(savepath)
return seg_comp_df
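
A usage sketch, assuming two segmented anndata objects on disk; the paths, segmentation names, and barcode are placeholders:

import anndata as ad

adata_a = ad.read_h5ad("/path/to/seg_a.h5ad")  # hypothetical files with obsm["spatial"]
adata_b = ad.read_h5ad("/path/to/seg_b.h5ad")  # or center_x/center_y already in .obs
seg_comp_df = get_segmentation_data(
    bc="1234567890",
    anndata_a=adata_a,
    anndata_b=adata_b,
    seg_name_a="cellpose",
    seg_name_b="proseg",
    save=True,
    savepath="/path/to/output/",  # concatenated directly, so keep the trailing slash
    reuse_saved=False,
    min_transcripts=40,
)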


def get_mutual_matches(dfa, dfb, nn_dist):
"""
Compares x and y locations across the two segmentation results with a nearest-neighbor search to identify cells that are likely the same cell in both segmentations
INPUTS
dfa, dfb: subset dataframe for each segmentation
nn_dist: distance cutoff for identifying likely same cell between segmentations. float.
OUTPUT:
mutual_matches_stacked: dictionary mapping cell IDs to matched cell IDs. Keys include both source a and source b cell IDs, and each value is the matched cell ID from the other source (source a keys have source b values, and vice versa)
"""

# get the single nearest neighbor in each direction
tree_a = cKDTree(dfa[["center_x", "center_y"]].copy())
dists_a, inds_a = tree_a.query(
dfb[["center_x", "center_y"]].copy(), 1
)  # index locations into dfa of the nearest dfa cell for each dfb cell
tree_b = cKDTree(dfb[["center_x", "center_y"]].copy())
dists_b, inds_b = tree_b.query(
dfa[["center_x", "center_y"]].copy(), 1
)  # and vice versa: nearest dfb cell for each dfa cell
# get mutually matching pairs and save (index of seg_comp_df is str)
match_to_dfb = pd.DataFrame(
data=dfa.iloc[inds_a].index.values.tolist(),
index=pd.Index(dfb.index, name="match_dfb_index"),
columns=["match_dfa_index"],
)
match_to_dfb["same_cell"] = np.where(dists_a <= nn_dist, True, False)
match_to_dfa = pd.DataFrame(
data=dfa.index,
index=pd.Index(dfb.iloc[inds_b].index.values.tolist(), name="match_dfb_index"),
columns=["match_dfa_index"],
)
match_to_dfa["same_cell"] = np.where(dists_b <= nn_dist, True, False)
mutual_matches = pd.merge(match_to_dfa, match_to_dfb, how="outer")
mutual_matches = mutual_matches.set_index("match_dfa_index").join(
match_to_dfa.reset_index().set_index("match_dfa_index"),
how="left",
rsuffix="_match",
)
mutual_match_dict = (
mutual_matches[mutual_matches["same_cell"] == True]
.drop(["same_cell", "same_cell_match"], axis=1)
.to_dict()
)
inv_mutual_match_dict = {
v: k for k, v in mutual_match_dict["match_dfb_index"].items()
}
# the dict union operator (|) requires Python 3.9+ (added by PEP 584); discussed at https://stackoverflow.com/questions/38987/how-do-i-merge-two-dictionaries-in-a-single-expression-in-python
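# minimal illustration of the union operator on plain dicts:
#   {"a_1": "b_1"} | {"b_1": "a_1"} == {"a_1": "b_1", "b_1": "a_1"}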
mutual_matches_stacked = (
mutual_match_dict["match_dfb_index"] | inv_mutual_match_dict
)
return mutual_matches_stacked
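
A self-contained toy example of the mutual matching; the cell IDs and coordinates are made up:

import pandas as pd

# two segmentations of the same two cells, offset by ~0.5 um
dfa = pd.DataFrame(
    {"center_x": [0.0, 10.0], "center_y": [0.0, 10.0]},
    index=["segA_0", "segA_1"],
)
dfb = pd.DataFrame(
    {"center_x": [0.5, 10.0], "center_y": [0.0, 10.5]},
    index=["segB_0", "segB_1"],
)
matches = get_mutual_matches(dfa, dfb, nn_dist=2.5)
# both directions are stacked into one dict, so we expect:
# {"segA_0": "segB_0", "segA_1": "segB_1", "segB_0": "segA_0", "segB_1": "segA_1"}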


def create_node_df_sankey(
seg_comp_df,
barcode,
save=True,
savepath="/allen/programs/celltypes/workgroups/hct/emilyg/reseg_project/new_seg/",
):
"""
Builds a dataframe describing each node to be included in the sankey diagram, plus data helpful for generating the required links dataframe.
INPUTS
seg_comp_df: dataframe with unique index describing cell spatial locations (x and y), segmentation identifier, low quality cell identifier, mutual matches, and putative doublets.
barcode: unique identifier for section, used for creating the save name for the dataframe, string.
save: whether to save the final results. Bool, default true
savepath: path to which results should be saved, string
OUTPUTS
nodes_df: dataframe containing node label, color, and value for sankey diagram creation, with additional level and source keys for ease of link dataframe creation
"""

color_dict = {
seg_comp_df.source.unique().tolist()[0]: "red",
seg_comp_df.source.unique().tolist()[1]: "blue",
}
nodes_df = pd.DataFrame(columns=["Label", "Color", "Level", "Source", "Value"])
unknown_unmatched_cells = {}
for source, g in seg_comp_df.groupby("source"):
new_rows = []
low_q_and_match = g.groupby("low_quality_cells").agg(
"count"
)  # non-null counts per column, split by low quality True/False
total_cells = low_q_and_match.iloc[:, 0].sum()
nodes_df.loc[len(nodes_df)] = {
"Label": "total <br>" + str(total_cells),
"Color": color_dict[source],
"Level": 0,
"Source": source,
"Value": total_cells,
}
# low and normal quality cells
nodes_df.loc[len(nodes_df)] = {
"Label": "low quality cells <br>" + str(low_q_and_match.iloc[1, 1]),
"Color": color_dict[source],
"Level": 1,
"Source": source,
"Value": low_q_and_match.iloc[1, 1],
}
nodes_df.loc[len(nodes_df)] = {
"Label": "normal quality cells <br>" + str(low_q_and_match.iloc[0, 1]),
"Color": color_dict[source],
"Level": 1,
"Source": source,
"Value": low_q_and_match.iloc[0, 1],
}
# matched and unmatched cells
# row 0 of the agg is the low_quality_cells == False (normal quality) group;
# the non-null count in the match column (iloc[0, 3]) counts unmatched cells
matched_cells = low_q_and_match.iloc[0, 1] - low_q_and_match.iloc[0, 3]
nodes_df.loc[len(nodes_df)] = {
"Label": "matched cells <br>" + str(matched_cells),
"Color": color_dict[source],
"Level": 2,
"Source": source,
"Value": matched_cells,
}
nodes_df.loc[len(nodes_df)] = {
"Label": "unmatched cells <br>" + str(low_q_and_match.iloc[0, 3]),
"Color": color_dict[source],
"Level": 2,
"Source": source,
"Value": low_q_and_match.iloc[0, 3],
}
# raise a flag if too few cells matched, which may indicate a scaling issue
if matched_cells <= (0.1 * len(g)):
raise ValueError(
"The number of matched cells is 10% or less of total cells present. Please check your inputs for potential scaling issues."
)

rem_col_names = [
x for x in g.columns[5:] if source + "_unfilt" not in x
] # get remaining columns
# pulling this to access unknown unmatched cells easily later
high_q_df = g[g[g.columns.tolist()[3]] == False]
unmatched_df = high_q_df[high_q_df[g.columns.tolist()[4]].isna() == False]
rem_col_df = unmatched_df.loc[:, rem_col_names]
rem_col_counts = rem_col_df.isna().value_counts().reset_index()
name = " ".join(rem_col_names[0].split("_")) + "<br>"
fcounts = rem_col_counts[rem_col_counts[rem_col_names[0]] == False][
"count"
].values.tolist()[0]
nodes_df.loc[len(nodes_df)] = {
"Label": name + str(fcounts),
"Color": color_dict[source],
"Level": 3,
"Source": source,
"Value": fcounts,
}
# remaining unmatched cells (known or unknown)
if len(rem_col_counts) > 1:
rem_unmatched_cells = rem_col_counts[
rem_col_counts[rem_col_names[0]] == True
]["count"].values.tolist()[0]
unknown_unmatched_cells[source] = rem_col_df[
rem_col_df[rem_col_names[0]].isna() == True
].index.values.tolist()
name = " ".join(rem_col_names[-1].split("_"))
nodes_df.loc[len(nodes_df)] = {
"Label": name + str(rem_unmatched_cells),
"Color": color_dict[source],
"Level": 3,
"Source": source,
"Value": rem_unmatched_cells,
}
if save:
nodes_df.to_csv(savepath + "/sankey_nodes_df_" + barcode + ".csv")
return nodes_df, unknown_unmatched_cells


def create_link_df_sankey(
nodes_df,
barcode,
save=True,
savepath="/allen/programs/celltypes/workgroups/hct/emilyg/reseg_project/new_seg/",
):
"""
Generates links dataframe from nodes dataframe, connecting relevant nodes to one another and ensuring distinct segmentations are separated.
INPUTS
nodes_df: dataframe containing information on each node for sankey diagram.
barcode: unique identifier for section, used for creating the save name for the dataframe, string.
save: whether to save the final results. Bool, default true
savepath: path to which results should be saved, string
OUTPUTS
links_df: dataframe containing source, target, value, and colors for sankey diagram creation
"""

link_color_dict = {"red": "rgb(205, 209, 228)", "blue": "rgb(205, 209, 228)"}
links_df = pd.DataFrame(columns=["Source", "Target", "Value", "Link Color"])
for i, row in nodes_df.iterrows():
if "total" in row.Label:
continue # skip 0 since all comes from zero
else:
target = i
value = row["Value"]
source = nodes_df[
(nodes_df["Level"] == row["Level"] - 1)
& (nodes_df["Source"] == row["Source"])
].index.values.tolist()
if len(source) > 1 and row["Level"] != 3:
source = [
i
for i, x in nodes_df.iloc[source, :].iterrows()
if "un" not in x["Label"]
if "low" not in x["Label"]
][0]
elif len(source) > 1 and row["Level"] == 3:
source = [
i
for i, x in nodes_df.iloc[source, :].iterrows()
if "un" in x["Label"]
][0]
else:
source = source[0]
prev_row = nodes_df.iloc[source, :]
link_color = link_color_dict[prev_row["Color"]]
links_df.loc[len(links_df)] = {
"Source": source,
"Target": target,
"Value": value,
"Link Color": link_color,
}

# special case: matched cells should have same name, new source, new color
matched_cell_rows = nodes_df[
(nodes_df["Label"].str.contains("matched cells"))
& (~nodes_df["Label"].str.contains("un"))
]
matched_cell_val = matched_cell_rows["Value"].values.tolist()[0]
nodes_df.loc[len(nodes_df)] = {
"Label": "matched cells <br> " + str(matched_cell_val),
"Color": "violet",
"Level": matched_cell_rows["Level"].values.tolist()[0],
"Source": "Both",
"Value": matched_cell_val,
}
# find the combined "Both" node in nodes_df, then redirect the links that
# targeted either segmentation's matched-cells node to point at it
update_idxs = links_df[
links_df["Target"].isin(matched_cell_rows.index.values.tolist())
].index.values.tolist()
both_node_idx = nodes_df[nodes_df["Source"] == "Both"].index.values[0]
links_df.loc[update_idxs, "Target"] = both_node_idx
if save:
links_df.to_csv(savepath + "/sankey_links_df_" + barcode + ".csv")
return links_df
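
To render the diagram, a sketch of feeding both dataframes to plotly (assuming plotly is installed; seg_comp_df comes from the pipeline above and the barcode is a placeholder). Note that create_link_df_sankey appends the combined matched-cells node to nodes_df in place, so the figure is built after both calls:

import plotly.graph_objects as go

nodes_df, _ = create_node_df_sankey(seg_comp_df, barcode="1234567890", save=False)
links_df = create_link_df_sankey(nodes_df, barcode="1234567890", save=False)
fig = go.Figure(
    go.Sankey(
        node=dict(label=nodes_df["Label"], color=nodes_df["Color"]),
        link=dict(
            source=links_df["Source"],
            target=links_df["Target"],
            value=links_df["Value"],
            color=links_df["Link Color"],
        ),
    )
)
fig.show()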
