Skip to content

Commit

Permalink
Merge pull request #16 from YosefLab/plot-sparse
Browse files Browse the repository at this point in the history
Plot sparse
  • Loading branch information
colganwi authored Oct 25, 2024
2 parents dc22fa0 + f18cccf commit 771b420
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 14 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ fail_fast: false
default_language_version:
python: python3
default_stages:
- commit
- push
- pre-commit
- pre-push
minimum_pre_commit_version: 2.16.0
repos:
- repo: https://github.com/pre-commit/mirrors-prettier
Expand Down
5 changes: 4 additions & 1 deletion src/pycea/pl/plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def annotation(
width: int | float = 0.05,
gap: int | float = 0.03,
label: bool | str | Sequence[str] = True,
layer: str | None = None,
border_width: int | float = 0,
tree: str | Sequence[str] | None = None,
cmap: str | mcolors.Colormap = None,
Expand All @@ -386,6 +387,8 @@ def annotation(
label
Annotation labels. If `True`, the keys are used as labels.
If a string or a sequence of strings, the strings are used as labels.
layer
Name of the TreeData object layer to use. If `None`, `tdata.X` is plotted.
border_width
The width of the border around the annotation bar.
tree
Expand Down Expand Up @@ -416,7 +419,7 @@ def annotation(
cmap = plt.get_cmap(cmap)
leaves = attrs["leaves"]
# Get data
data, is_array = get_keyed_obs_data(tdata, keys)
data, is_array = get_keyed_obs_data(tdata, keys, layer=layer)
numeric_data = data.select_dtypes(exclude="category")
if len(numeric_data) > 0 and not vmin:
vmin = numeric_data.min().min()
Expand Down
10 changes: 8 additions & 2 deletions src/pycea/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,16 @@ def get_keyed_obs_data(tdata: td.TreeData, keys: str | Sequence[str], layer: str
data.append(pd.Series(tdata.obs_vector(key, layer=layer), index=tdata.obs_names))
column_keys = True
elif "obsm" in dir(tdata) and key in tdata.obsm.keys():
data.append(pd.DataFrame(tdata.obsm[key], index=tdata.obs_names))
if sp.sparse.issparse(tdata.obsm[key]):
data.append(pd.DataFrame(tdata.obsm[key].toarray(), index=tdata.obs_names))
else:
data.append(pd.DataFrame(tdata.obsm[key], index=tdata.obs_names))
array_keys = True
elif "obsp" in dir(tdata) and key in tdata.obsp.keys():
data.append(pd.DataFrame(tdata.obsp[key], index=tdata.obs_names, columns=tdata.obs_names))
if sp.sparse.issparse(tdata.obsp[key]):
data.append(pd.DataFrame(tdata.obsp[key].toarray(), index=tdata.obs_names, columns=tdata.obs_names))
else:
data.append(pd.DataFrame(tdata.obsp[key], index=tdata.obs_names, columns=tdata.obs_names))
array_keys = True
else:
raise ValueError(
Expand Down
8 changes: 0 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1 @@
import pytest
import treedata as td

_tdata = td.read_h5ad("tests/data/tdata.h5ad")


@pytest.fixture(scope="session")
def tdata() -> td.TreeData:
return _tdata
10 changes: 9 additions & 1 deletion tests/test_plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@

import matplotlib.pyplot as plt
import pytest
import treedata as td

import pycea

plot_path = Path(__file__).parent / "plots"


@pytest.fixture
def tdata() -> td.TreeData:
return td.read_h5ad("tests/data/tdata.h5ad")


def test_polar_with_clades(tdata):
fig, ax = plt.subplots(dpi=300, subplot_kw={"polar": True})
pycea.pl.branches(
Expand All @@ -32,7 +38,7 @@ def test_angled_numeric_annotations(tdata):


def test_matrix_annotation(tdata):
fig, ax = plt.subplots(dpi=300)
fig, ax = plt.subplots(dpi=300, figsize=(7, 3))
pycea.pl.tree(
tdata,
nodes="internal",
Expand All @@ -42,6 +48,8 @@ def test_matrix_annotation(tdata):
keys=["spatial_distances"],
ax=ax,
)
pycea.tl.tree_neighbors(tdata, max_dist=5, depth_key="time",update=False)
pycea.pl.annotation(tdata, keys="tree_connectivities", ax=ax,cmap = "Purples")
plt.savefig(plot_path / "matrix_annotation.png")
plt.close()

Expand Down
20 changes: 20 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import networkx as nx
import numpy as np
import pandas as pd
import pytest
import scipy as sp
import treedata as td

from pycea.utils import (
Expand Down Expand Up @@ -30,11 +32,14 @@ def tree():
@pytest.fixture
def tdata(tree):
tdata = td.TreeData(
X=np.array([[1, 2], [3, 4]]),
obs=pd.DataFrame({"value": ["1", "2"]}, index=["D", "E"]),
obst={"tree": tree, "tree2": tree},
obsm={"spatial": pd.DataFrame([[0, 0], [1, 1]], index=["D", "E"])},
obsp={"dense":np.eye(2), "sparse":sp.sparse.csr_matrix(np.eye(2))},
allow_overlap=True,
)
tdata.layers["scaled"] = tdata.X
yield tdata


Expand Down Expand Up @@ -98,15 +103,30 @@ def test_get_keyed_obs_data_valid_keys(tdata):
# Automatically converts object columns to category
assert data["value"].dtype == "category"
assert tdata.obs["value"].dtype == "category"
# Gets gene expression data
data, is_array = get_keyed_obs_data(tdata, "1", layer="scaled")
assert not is_array
assert data["1"].tolist() == [2, 4]


def test_get_keyed_obs_data_array(tdata):
# spatial obsm array
data, is_array = get_keyed_obs_data(tdata, ["spatial"])
assert data.columns.tolist() == [0, 1]
assert data[0].tolist() == [0, 1]
assert is_array
assert isinstance(data, pd.DataFrame)
assert data.shape[1] == 2
# dense obsp array
data, is_array = get_keyed_obs_data(tdata, ["dense"])
assert is_array
assert isinstance(data, pd.DataFrame)
assert data.shape[1] == 2
# sparse obsp array
data, is_array = get_keyed_obs_data(tdata, ["sparse"])
assert is_array
assert isinstance(data, pd.DataFrame)
assert data.shape[1] == 2


def test_get_keyed_obs_data_invalid_keys(tdata):
Expand Down

0 comments on commit 771b420

Please sign in to comment.