Skip to content

Commit

Permalink
Merge pull request #146 from dattalab/better_error_catching
Browse files Browse the repository at this point in the history
Better error catching + change default iters for apply_model
  • Loading branch information
calebweinreb authored Apr 16, 2024
2 parents af0988a + a4bc777 commit a4b3766
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 8 deletions.
36 changes: 36 additions & 0 deletions docs/source/FAQs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,42 @@ The final output of keypoint MoSeq is a results .h5 file (and optionally a direc
Low-dimensional representation of the animal's pose in each frame. These are similar to PCA scores, are modified to reflect the pose dynamics and noise estimates inferred by the model.


Validating results when applying a model to new data
---------------------------------------------------
When applying a model to new data, it may be useful to generate new grid movies and trajectory plots so you can confirm that the meaning of the syllables has been preserved. Let's say you've already applied the model to new data as follows:

.. code-block:: python
# load new data (e.g. from deeplabcut)
coordinates, confidences, bodyparts = kpms.load_keypoints(new_data_path, 'deeplabcut')
data, metadata = kpms.format_data(coordinates, confidences, **config())
# apply saved model to new data
results = kpms.apply_model(model, data, metadata, project_dir, model_name, **config())
By default, the `results` dictionary above contains results for both the new and old data. To generate grid movies and trajectory plots for the new data only, we can subset the `results` dictionary to include only the new data. We will also need to specify alternative paths for saving the new movies and plots so the original ones aren't overwritten.

.. code-block:: python
import os
# only include results for the new data
new_results = {k:v for k,v in results.items() if k in coordinates}
# save trajectory plots for the new data
output_dir = os.path.join(project_dir, model_name, "new_trajectory_plots")
kpms.generate_trajectory_plots(
coordinates, new_results, project_dir,model_name, output_dir=output_dir, **config()
)
# save grid movies for the new data
output_dir = os.path.join(project_dir, model_name, "new_grid_movies")
kpms.generate_grid_movies(
new_results, project_dir, model_name, coordinates=coordinates, output_dir=output_dir, **config()
);
Visualization
=============

Expand Down
6 changes: 3 additions & 3 deletions docs/source/advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ Model averaging

Keypoint-MoSeq is probabilistic. So even once fitting is complete and the syllable parameters are fixed, there is still a distribution of possible syllable sequences given the observed data. In the default pipeline, one such sequence is sampled from this distribution and used for downstream analyses. Alternatively, one can estimate the marginal probability distribution over syllable labels at each timepoint. The code below shows how to do this. It can be applied to new data or the same data that was used for fitting (or a combination of the two).::

burnin_iters = 50
burnin_iters = 200
num_samples = 100
steps_per_sample = 5

Expand Down Expand Up @@ -273,8 +273,8 @@ Temporal downsampling
Sometimes it's useful to downsample a dataset, e.g. if the original recording has a much higher framerate than is needed for modeling. To downsample, run the following lines right after loading the keypoints.::

downsample_rate = 2 # keep every 2nd frame
kpms.downsample_timepoints(coordinates, downsample_rate)
kpms.downsample_timepoints(confidences, downsample_rate) # skip if `confidences=None`
coordinates = kpms.downsample_timepoints(coordinates, downsample_rate)
confidences = kpms.downsample_timepoints(confidences, downsample_rate) # skip if `confidences=None`

After this, the pipeline can be run as usual, except for steps that involve reading the original videos, in which case ``downsample_rate`` should be passed as an additional argument.::

Expand Down
8 changes: 4 additions & 4 deletions keypoint_moseq/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def apply_model(
metadata,
project_dir=None,
model_name=None,
num_iters=50,
num_iters=500,
ar_only=False,
save_results=True,
verbose=False,
Expand Down Expand Up @@ -321,7 +321,7 @@ def apply_model(
Name of the model. Required if `save_results=True` and
`results_path=None`.
num_iters : int, default=50
num_iters : int, default=500
Number of iterations to run the model.
ar_only : bool, default=False
Expand Down Expand Up @@ -416,7 +416,7 @@ def estimate_syllable_marginals(
model,
data,
metadata,
burn_in_iters=50,
burn_in_iters=200,
num_samples=100,
steps_per_sample=10,
return_samples=False,
Expand All @@ -440,7 +440,7 @@ def estimate_syllable_marginals(
Recordings and start/end frames for the data (see
:py:func:`keypoint_moseq.io.format_data`).
burn_in_iters : int, default=50
burn_in_iters : int, default=200
Number of resampling iterations to run before collecting samples.
num_samples : int, default=100
Expand Down
53 changes: 53 additions & 0 deletions keypoint_moseq/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from scipy.spatial.distance import pdist, squareform
from jax_moseq.models.keypoint_slds import inverse_rigid_transform
from jax_moseq.utils import get_frequencies, batch
from vidio.read import OpenCVReader

na = jnp.newaxis

Expand Down Expand Up @@ -1177,3 +1178,55 @@ def downsample_timepoints(data, downsample_rate):
return {k: downsample_timepoints(v, downsample_rate) for k, v in data.items()}
else:
return data[::downsample_rate]


def check_video_paths(video_paths, keys):
"""
Check if video paths are valid and match the keys.
Parameters
----------
video_paths: dict
Dictionary mapping keys to video paths.
keys: list
List of keys that require a video path.
Raises
------
ValueError
If any of the following are true:
- a video path is not provided for a key in `keys`
- a video isn't readable.
- a video path does not exist.
"""
missing_keys = set(keys) - set(video_paths.keys())

nonexistent_videos = []
unreadable_videos = []
for path in video_paths.values():
if not os.path.exists(path):
nonexistent_videos.append(path)
else:
try:
OpenCVReader(path)[0]
except:
unreadable_videos.append(path)

error_messages = []

if len(missing_keys) > 0:
error_messages.append(
"The following keys require a video path: {}".format(missing_keys)
)
if len(nonexistent_videos) > 0:
error_messages.append(
"The following videos do not exist: {}".format(nonexistent_videos)
)
if len(unreadable_videos) > 0:
error_messages.append(
"The following videos are not readable and must be reencoded: {}".format(unreadable_videos)
)

if len(error_messages) > 0:
raise ValueError("\n\n".join(error_messages))
1 change: 1 addition & 0 deletions keypoint_moseq/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,6 +1196,7 @@ def generate_grid_movies(
as_dict=True,
video_extension=video_extension,
)
check_video_paths(video_paths, results.keys())
videos = {k: OpenCVReader(path) for k, path in video_paths.items()}

if fps is None:
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ install_requires =
tabulate
commentjson
jaxtyping==0.2.14
jax-moseq==0.2.1
jax-moseq==0.2.2

[options.package_data]
* = *.md
Expand Down

0 comments on commit a4b3766

Please sign in to comment.