Skip to content

Commit

Permalink
Merge pull request #163 from dattalab/sleap_io_hotfix
Browse files Browse the repository at this point in the history
correct processing of axis order from sleap_io.Labels.numpy()
  • Loading branch information
calebweinreb authored Sep 4, 2024
2 parents 706e79e + 042bc68 commit 7530ee9
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 8 deletions.
16 changes: 8 additions & 8 deletions keypoint_moseq/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,20 +1134,20 @@ def _sleap_loader(filepath, name):

bodyparts = slp_file.skeletons[0].node_names
arr = slp_file.numpy(return_confidence=True)
coords = arr[:, :, :-1]
confs = arr[:, :, -1]
coords = arr[:, :, :, :-1].transpose((1, 0, 2, 3))
confs = arr[:, :, :, -1].transpose((1, 0, 2))
else:
with h5py.File(filepath, "r") as f:
coords = f["tracks"][()]
confs = f["point_scores"][()]
coords = f["tracks"][()].transpose((0, 3, 2, 1))
confs = f["point_scores"][()].transpose((0, 2, 1))
bodyparts = [name.decode("utf-8") for name in f["node_names"]]

if coords.shape[0] == 1:
coordinates = {name: coords[0].T}
confidences = {name: confs[0].T}
coordinates = {name: coords[0]}
confidences = {name: confs[0]}
else:
coordinates = {f"{name}_track{i}": coords[i].T for i in range(coords.shape[0])}
confidences = {f"{name}_track{i}": confs[i].T for i in range(coords.shape[0])}
coordinates = {f"{name}_track{i}": coords[i] for i in range(coords.shape[0])}
confidences = {f"{name}_track{i}": confs[i] for i in range(coords.shape[0])}
return coordinates, confidences, bodyparts


Expand Down
53 changes: 53 additions & 0 deletions keypoint_moseq/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -2913,3 +2913,56 @@ def plot_eml_scores(eml_scores, eml_std_errs, model_names):
ax.set_ylabel("EML score")
plt.tight_layout()
return fig, ax


def plot_pose(
coordinates, bodyparts, skeleton, cmap="autumn", node_size=6, linewidth=3, ax=None
):
"""
Plot a single pose using matplotlib.
Parameters
----------
coordinates: ndarray of shape (num_bodyparts, 2)
2D coordinates of the pose.
bodyparts: list of str
Bodypart names.
skeleton: list of tuples
Skeleton edges as pairs of bodypart names.
cmap: str, default='autumn'
Colormap to use for coloring keypoints.
node_size: float, default=6
Size of keypoints.
linewidth: float, default=3
Width of skeleton edges.
ax: matplotlib axis, default=None
Axis to plot on. If None, a new axis is created.
Returns
-------
ax: matplotlib axis
Axis containing the plot.
"""
if ax is None:
fig, ax = plt.subplots(1, 1)

cmap = plt.get_cmap(cmap)
colors = cmap(np.linspace(0, 1, len(bodyparts)))
edges = get_edges(bodyparts, skeleton)

for i, (x, y) in enumerate(coordinates):
ax.scatter(x, y, s=node_size, c=[colors[i]])

for i, j in edges:
x = [coordinates[i, 0], coordinates[j, 0]]
y = [coordinates[i, 1], coordinates[j, 1]]
ax.plot(x, y, c=colors[i], linewidth=linewidth)

ax.set_aspect("equal")
return ax

0 comments on commit 7530ee9

Please sign in to comment.