diff --git a/spatial_compare/__init__.py b/spatial_compare/__init__.py index db77036..8a5bc57 100644 --- a/spatial_compare/__init__.py +++ b/spatial_compare/__init__.py @@ -1,7 +1,2 @@ from .spatial_compare import SpatialCompare, get_column_ordering -from .utils import ( - grouped_obs_mean, - spatial_detection_scores, - summarize_and_plot, - compare_reference_and_spatial, -) +from .utils import grouped_obs_mean,spatial_detection_scores,summarize_and_plot, compare_reference_and_spatial, detect_outliers diff --git a/spatial_compare/utils.py b/spatial_compare/utils.py index 1e4ca02..5c7cf40 100644 --- a/spatial_compare/utils.py +++ b/spatial_compare/utils.py @@ -5,6 +5,7 @@ import seaborn as sns + def grouped_obs_mean(adata, group_key, layer=None, gene_symbols=None): """ Calculate the mean expression of observations grouped by a specified key. @@ -25,6 +26,7 @@ def grouped_obs_mean(adata, group_key, layer=None, gene_symbols=None): and rows correspond to genes. """ + if layer is not None: getX = lambda x: x.layers[layer] else: @@ -38,30 +40,22 @@ def grouped_obs_mean(adata, group_key, layer=None, gene_symbols=None): out = pd.DataFrame( np.zeros((adata.shape[1], len(grouped)), dtype=np.float64), columns=list(grouped.groups.keys()), - index=new_idx, + index=new_idx ) for group, idx in grouped.indices.items(): X = getX(adata[idx]) - out[group] = np.ravel(X.mean(axis=0, dtype=np.float64)).tolist() - + out[group] = np.ravel(X.mean(axis=0, dtype=np.float64)).tolist() + return out - -def spatial_detection_scores( - reference: pd.DataFrame, - query: pd.DataFrame, - plot_stuff=True, - query_name: str = "query data", - comparison_column="transcript_counts", - category="supercluster_name", - n_bins=50, - in_place=True, - non_spatial=False, -): +def spatial_detection_scores(reference: pd.DataFrame, query: pd.DataFrame, + plot_stuff=True,query_name: str="query data",comparison_column="transcript_counts", + category = "supercluster_name", + n_bins = 50, in_place=True, non_spatial = False): """ Calculate and plot spatial detection scores for query data compared to reference data. - + Parameters: reference (pd.DataFrame): The reference data. query (pd.DataFrame): The query data. @@ -78,146 +72,117 @@ def spatial_detection_scores( # code goes here if category not in reference.columns or category not in query.columns: - raise ValueError("category " + category + " not in reference and query inputs") + raise ValueError("category "+category+" not in reference and query inputs") - shared_category_values = list( - set(reference[category].unique()) & set(query[category].unique()) - ) - if ( - len(shared_category_values) < query[category].unique().shape[0] - or len(shared_category_values) < reference[category].unique().shape[0] - ): - in_place = False + shared_category_values = list(set(reference[category].unique())& set(query[category].unique())) + if (len(shared_category_values) min_cells_per_bin - bad_frac = np.sum( - spatial_density_results[z]["z_score_image"][tissue_bins] <= z_score_limit - ) / np.sum(tissue_bins) + for ii,z in enumerate(sorted(list(spatial_density_results.keys()))): + #estimate tissue bins: + tissue_bins = spatial_density_results[z]["count_image"]>min_cells_per_bin + bad_frac = np.sum(spatial_density_results[z]["z_score_image"][tissue_bins]<=z_score_limit)/np.sum(tissue_bins) image_to_show = spatial_density_results[z]["z_score_image"].copy() - image_to_show[np.logical_not(tissue_bins)] = 0 + image_to_show[np.logical_not(tissue_bins)]=0 + if plot_stuff: if z in title_mapping: title_prefix = title_mapping[z] else: - title_prefix = "" - - plt.subplot( - 1 + len(list(spatial_density_results.keys())) // plot_columns, - plot_columns, - ii + 1, - ) - - ax = plt.imshow( - image_to_show, - vmin=-1, - vmax=1, - extent=spatial_density_results[z]["extent"], - cmap="coolwarm_r", - ) - - if bad_frac >= area_frac_limit: - plt.title(title_prefix + " Fail", fontdict={"size": 12}) + title_prefix="" + + plt.subplot(1+len(list(spatial_density_results.keys()))//plot_columns,plot_columns, ii+1) + + ax = plt.imshow(image_to_show, vmin = -1, vmax=1, extent = spatial_density_results[z]["extent"], cmap = 'coolwarm_r') + + if bad_frac>=area_frac_limit: + plt.title(title_prefix+" Fail", fontdict={"size":12}) else: - plt.title(title_prefix, fontdict={"size": 12}) + plt.title(title_prefix, fontdict={"size":12}) plt.xticks([]) plt.yticks([]) - - if bad_frac >= area_frac_limit: + + if bad_frac>=area_frac_limit: fails_vs_rnaseq.append(z) - results.append( - dict( - key=z, - bad_frac=bad_frac, - area_frac_limit=area_frac_limit, - failed=bad_frac >= area_frac_limit, - mean=np.mean(image_to_show[image_to_show != 0]), - stdev=np.std(image_to_show[image_to_show != 0]), - ) - ) + results.append(dict(key=z, + bad_frac=bad_frac, + area_frac_limit = area_frac_limit, + failed = bad_frac>=area_frac_limit, + mean = np.mean(image_to_show[image_to_show!=0]), + stdev = np.std(image_to_show[image_to_show!=0]))) return results -def compare_reference_and_spatial( - reference_anndata, - spatial_anndata, - category="MTG_subclass_name", - layer_field=None, - plot_stuff=True, - target_obs_key="comparison_transcript_counts", - ok_to_clobber=False, -): +def compare_reference_and_spatial(reference_anndata,spatial_anndata, + category="MTG_subclass_name", layer_field = None, + plot_stuff=True, + target_obs_key = "comparison_transcript_counts", ok_to_clobber=False): """ Compare reference and spatial data based on a specified category. Parameters: @@ -322,59 +268,237 @@ def compare_reference_and_spatial( """ if target_obs_key in reference_anndata.obs.columns: if not ok_to_clobber == True: - raise ValueError( - "obs key " - + target_obs_key - + " is already in the input reference .obs\n If desired, set ok_to_clobber to True" - ) + raise ValueError("obs key "+target_obs_key+" is already in the input reference .obs\n If desired, set ok_to_clobber to True") else: - print( - "warning: modifying input reference anndata .obs field " - + target_obs_key - ) + print("warning: modifying input reference anndata .obs field "+target_obs_key) if layer_field is not None: - reference_anndata.obs[target_obs_key] = np.sum( - reference_anndata.layers[layer_field], axis=1 - ) + reference_anndata.obs[target_obs_key] = np.sum(reference_anndata.layers[layer_field],axis=1) else: - reference_anndata.obs[target_obs_key] = np.sum(reference_anndata.X, axis=1) + reference_anndata.obs[target_obs_key] = np.sum(reference_anndata.X,axis=1) - means = ( - reference_anndata.obs.loc[:, [category, target_obs_key]] - .groupby(category, observed=True) - .mean() - ) + + means = reference_anndata.obs.loc[:,[category,target_obs_key]].groupby(category, observed=True).mean() means["spatial_counts"] = 0.0 + for name in spatial_anndata.obs[category].unique(): - means.loc[name, ["spatial_counts"]] = ( - spatial_anndata.obs.loc[ - spatial_anndata.obs[category] == name, [target_obs_key] - ] - .mean() - .values[0] - ) + means.loc[name,["spatial_counts"]] = spatial_anndata.obs.loc[spatial_anndata.obs[category]==name,[target_obs_key]].mean().values[0] + + - fit_values = np.polyfit(means.spatial_counts, means[target_obs_key], 1) + fit_values = np.polyfit(means.spatial_counts,means[target_obs_key], 1) - scale_factor = 1 / fit_values[0] + scale_factor = 1/fit_values[0] if plot_stuff: plt.figure() - plt.plot(means.spatial_counts, means[target_obs_key], ".") + plt.plot(means.spatial_counts, means[target_obs_key],'.') plt.xlabel("spatial counts") - plt.ylabel(target_obs_key + " in reference anndata") + plt.ylabel(target_obs_key+" in reference anndata") plt.figure() - sns.regplot(x=means.spatial_counts, y=scale_factor * means[target_obs_key]) - plt.plot([0, 1000], [0, 1000], label="unity") - # plt.plot(means.spatial_counts, means.spatial_counts*fit_values[0]+fit_values[1], label = "fit") + sns.regplot(x=means.spatial_counts, y=scale_factor*means[target_obs_key]) + plt.plot([0,1000],[0,1000], label="unity") + #plt.plot(means.spatial_counts, means.spatial_counts*fit_values[0]+fit_values[1], label = "fit") plt.ylabel("scaled reference data") - plt.title("\n scale = " + str(scale_factor)[:6]) + plt.title("\n scale = "+str(scale_factor)[:6]) plt.legend() + + + + # scale transcript counts based on the linear fit. # could also do this per group. - reference_anndata.obs[target_obs_key] = ( - scale_factor * reference_anndata.obs[target_obs_key] - ) + reference_anndata.obs[target_obs_key] = scale_factor*reference_anndata.obs[target_obs_key] + + + +def detect_outliers(adata1, adata2, transform_type='none', gene_name='', inlier_threshold=0.7, display_others_plot=False, display_outlier_plot=False): + """ + Detect outliers in gene expression data using RANSAC regression. + + Parameters: + - adata1: AnnData object for dataset 1. + - adata2: AnnData object for dataset 2. + - transform_type: Type of transformation to apply ('scale', 'log', or 'none'). + - gene_name: Specific gene to visualize (if empty, visualizes the first inlier). + - inlier_threshold: Proportion of inliers required to classify a gene as an inlier. It’s a percentage of cells that got an "Inlier" label from RANSAC for a given gene based on their expression level. If your proportion_inliers >= inlier_threshold, the gene going through the loop will be labeled as an inlier. Otherwise, it'll be labeled as an outlier. + - display_others_plot: Boolean indicating whether to display the plot for a specific gene. Work in progress. + - display_outlier_plot: Boolean indicating whether to display the outlier plot. + + Returns: + - Prints lists of inlier and outlier genes. + """ + + # Function to ensure dense matrix + def ensure_dense(matrix): + return matrix.toarray() if hasattr(matrix, 'toarray') else matrix + + # Prepare inputs + X = ensure_dense(adata1.X) # Convert adata1.X to dense if necessary + Y = ensure_dense(adata2.X) # Convert adata2.X to dense if necessary + + genes1 = adata1.var_names + genes2 = adata2.var_names + + # Create DataFrames + df1 = pd.DataFrame(X, columns=genes1) + df2 = pd.DataFrame(Y, columns=genes2) + + # Remove genes containing 'UnassignedCodeword' from both datasets + mask1 = ~df1.columns.str.contains('UnassignedCodeword|NegControlProbe|NegControlCodeword', case=False) + mask2 = ~df2.columns.str.contains('UnassignedCodeword|NegControlProbe|NegControlCodeword', case=False) + + df1 = df1.loc[:, mask1] + df2 = df2.loc[:, mask2] + + # Find common genes after filtering + common_genes = list(set(df1.columns).intersection(set(df2.columns))) + + # Initialize lists to store results + inlier_genes = [] + outlier_genes = [] + + # Analyze each common gene independently + for gene in common_genes: + # Extract expression levels for the current gene + X_gene = df1[gene].dropna().to_numpy() # Drop NaN values if any + Y_gene = df2[gene].dropna().to_numpy() # Drop NaN values if any + + # If either dataset has no data for this gene, skip it + if len(X_gene) == 0 or len(Y_gene) == 0: + continue + + # Apply transformations based on the transform_type parameter + if transform_type == 'log': + X_gene = np.log1p(X_gene) # log(1 + x) handles zero values safely + Y_gene = np.log1p(Y_gene) + elif transform_type == 'scale': + scaler = StandardScaler() + X_gene = scaler.fit_transform(X_gene.reshape(-1, 1)) + Y_gene = scaler.fit_transform(Y_gene.reshape(-1, 1)) + + # Use only the minimum length of the two arrays for fitting + min_length = min(len(X_gene), len(Y_gene)) + + # Fit RANSAC regressor for the current gene using only available data + ransac = RANSACRegressor(random_state=42) + ransac.fit(X_gene[:min_length].reshape(-1, 1), Y_gene[:min_length].reshape(-1, 1)) + + # Extract inliers and outliers for this gene + inlier_mask = ransac.inlier_mask_ + + # Calculate proportion of inliers + proportion_inliers = np.sum(inlier_mask) / min_length + + # Check if proportion of inliers meets threshold + if proportion_inliers >= inlier_threshold: + inlier_genes.append(gene) + else: + outlier_genes.append(gene) + + # Select a specific gene to visualize (if provided) + selected_gene = gene_name if gene_name else (inlier_genes[0] if inlier_genes else None) + + if selected_gene and display_others_plot: + X_selected = df1[selected_gene].dropna().to_numpy().reshape(-1, 1) + Y_selected = df2[selected_gene].dropna().to_numpy().reshape(-1, 1) + + # Apply transformations for visualization based on transform_type + if transform_type == 'log': + X_selected = np.log1p(X_selected) + Y_selected = np.log1p(Y_selected) + elif transform_type == 'scale': + scaler = StandardScaler() + X_selected = scaler.fit_transform(X_selected) + Y_selected = scaler.fit_transform(Y_selected) + + # Align lengths for visualization + min_length = min(len(X_selected), len(Y_selected)) + + ransac.fit(X_selected[:min_length], Y_selected[:min_length]) + inlier_mask_visual = ransac.inlier_mask_ + + plt.figure(figsize=(8, 5)) + plt.scatter(X_selected[inlier_mask_visual], Y_selected[inlier_mask_visual], color='blue', label='Inliers') + plt.scatter(X_selected[~inlier_mask_visual], Y_selected[~inlier_mask_visual], color='red', label='Outliers') + + # Plotting the RANSAC fit line based on transformed data + plt.plot(X_selected, ransac.predict(X_selected), color='green', label='RANSAC fit') + + plt.xlabel(f'Gene Expression (Dataset 1) - {selected_gene}') + plt.ylabel(f'Gene Expression (Dataset 2) - {selected_gene}') + plt.title(f'RANSAC Outlier Detection for {selected_gene}') + plt.legend() + plt.show() + else: + print("No inliers found to visualize or plotting is disabled.") + + # Print outlier and inlier genes + print("Inlier Genes:", inlier_genes) + print("Inlier nGenes:", len(inlier_genes)) + print("Outlier Genes:", outlier_genes) + print("Outlier nGenes:", len(outlier_genes)) + + if display_outlier_plot: + + # Step 1: Extract common genes + common_genes = adata1.var_names.intersection(adata2.var_names) + + # Step 2: Sort common genes + sorted_common_genes = sorted(common_genes) + + # Step 3: Reorder AnnData objects based on sorted common genes + adata1_sorted = adata1[:, sorted_common_genes] + adata2_sorted = adata2[:, sorted_common_genes] + + # Step 4: Calculate average expression for sorted common genes + avg_expr_adata1 = np.asarray(adata1_sorted.X.mean(axis=0)).flatten() + avg_expr_adata2 = np.asarray(adata2_sorted.X.mean(axis=0)).flatten() + + # Check shapes + print("Shape of avg_expr_adata1:", avg_expr_adata1.shape) + print("Shape of avg_expr_adata2:", avg_expr_adata2.shape) + + # Step 5: Log-transform the average expression values + log_avg_expr_adata1 = np.log1p(avg_expr_adata1) # log(1 + x) + log_avg_expr_adata2 = np.log1p(avg_expr_adata2) + + # Step 6: Create a DataFrame with ordered data + data = pd.DataFrame({ + 'gene': sorted_common_genes, + 'adata1_avg': log_avg_expr_adata1, + 'adata2_avg': log_avg_expr_adata2 + }) + + # Step 7: Create a new column for color based on outlier genes + data['color'] = np.where(data['gene'].isin(outlier_genes), 'red', 'blue') + + # Step 8: Create scatter plot with colors for outliers and inliers + plt.figure(figsize=(8, 6)) + + sns.scatterplot(data=data, x='adata1_avg', y='adata2_avg', hue='color', palette={'red': 'red', 'blue': 'blue'}, legend=False) + + # Step 9: Annotate outlier genes on the plot + for index, row in data.iterrows(): + if row['color'] == 'red': + plt.text(row['adata1_avg'], row['adata2_avg'], row['gene'], fontsize=9, ha='right', va='bottom') + + # Step 10: Calculate and display correlation coefficient on log-transformed data + correlation = np.corrcoef(data['adata1_avg'], data['adata2_avg'])[0, 1] + plt.title(f'Scatter Plot of Log-Transformed Average Gene Expression\nCorrelation: {correlation:.2f}') + plt.xlabel('Log-Transformed Average Expression in adata1') + plt.ylabel('Log-Transformed Average Expression in adata2') + plt.grid(True) + + # Custom legend handles to match colors + legend_elements = [ + Line2D([0], [0], marker='o', color='w', label='Inlier Genes', markerfacecolor='blue', markersize=10), + Line2D([0], [0], marker='o', color='w', label='Outlier Genes', markerfacecolor='red', markersize=10) + ] + + plt.legend(handles=legend_elements, title='Gene Type') + plt.show() + + return inlier_genes, outlier_genes \ No newline at end of file