diff --git a/conda_envs/environment.linux_gpu.yml b/conda_envs/environment.linux_gpu.yml index dca1766..544f77c 100644 --- a/conda_envs/environment.linux_gpu.yml +++ b/conda_envs/environment.linux_gpu.yml @@ -14,5 +14,6 @@ dependencies: - pip - pip: - "keypoint-moseq" + - "jax-moseq[cuda11]" - jupyterlab - etils==1.5.2 \ No newline at end of file diff --git a/conda_envs/environment.win64_gpu.yml b/conda_envs/environment.win64_gpu.yml index e9ff628..5eca66b 100644 --- a/conda_envs/environment.win64_gpu.yml +++ b/conda_envs/environment.win64_gpu.yml @@ -16,4 +16,5 @@ dependencies: - jax==0.3.22 - "https://whls.blob.core.windows.net/unstable/cuda111/jaxlib-0.3.22+cuda11.cudnn82-cp39-cp39-win_amd64.whl" - "keypoint-moseq" + - "jax-moseq[cuda11]" - jupyterlab diff --git a/docs/keypoint_moseq_colab.ipynb b/docs/keypoint_moseq_colab.ipynb index 0587f98..6810848 100644 --- a/docs/keypoint_moseq_colab.ipynb +++ b/docs/keypoint_moseq_colab.ipynb @@ -1,662 +1,661 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "preliminary-agency", - "metadata": {}, - "source": [ - "This notebook shows how to setup a new project, train a keypoint-MoSeq model and visualize the resulting syllables. \n", - "\n", - "**Total run time: ~90 min.**\n", - "\n", - "# Colab setup\n", - "\n", - "- Make a copy of this notebook if you plan to make changes and want them saved.\n", - "- Go to \"Runtime\">\"change runtime type\" and select \"Python 3\" and \"GPU\"" - ] - }, - { - "cell_type": "markdown", - "id": "aaab50c4", - "metadata": {}, - "source": [ - "### Install keypoint MoSeq" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "84990f8c", - "metadata": {}, - "outputs": [], - "source": [ - "! pip install -U git+https://github.com/dattalab/jax-moseq.git@installation_refactor\n", - "! pip install -U git+https://github.com/dattalab/keypoint-moseq.git@installation_refactor\n", - "\n", - "import os\n", - "from google.colab import drive\n", - "drive.mount('/content/drive')" - ] - }, - { - "cell_type": "markdown", - "id": "df94500c", - "metadata": {}, - "source": [ - "### Option 1: Use our example dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e944d0e1", - "metadata": {}, - "outputs": [], - "source": [ - "import gdown\n", - "url = 'https://drive.google.com/uc?id=1JGyS9MbdS3MtrlYnh4xdEQwe2bYoCuSZ'\n", - "output = 'dlc_example_project.zip'\n", - "gdown.download(url, output, quiet=False)\n", - "! unzip dlc_example_project.zip\n", - "\n", - "data_dir = \"dlc_example_project\"" - ] - }, - { - "cell_type": "markdown", - "id": "abad448e", - "metadata": {}, - "source": [ - "### Option 2: Use your own data\n", - "Upload your data to google drive and then change the following path as needed" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f7df6bb2", - "metadata": {}, - "outputs": [], - "source": [ - "# data_dir = \"/content/drive/MyDrive/MY_DATA_DIRECTORY\"" - ] - }, - { - "cell_type": "markdown", - "id": "f8a52043", - "metadata": {}, - "source": [ - "# Project setup\n", - "Create a new project directory with a keypoint-MoSeq `config.yml` file." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "intermediate-kenya", - "metadata": {}, - "outputs": [], - "source": [ - "import keypoint_moseq as kpms\n", - "\n", - "project_dir = '/content/drive/MyDrive/demo_project/'\n", - "config = lambda: kpms.load_config(project_dir)" - ] - }, - { - "cell_type": "markdown", - "id": "012d8287", - "metadata": {}, - "source": [ - "### Option 1: Setup from DeepLabCut" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b4aa2fcd", - "metadata": { - "mystnb": { - "code_prompt_hide": "Setup from DeepLabCut", - "code_prompt_show": "Setup from DeepLabCut" + "cells": [ + { + "cell_type": "markdown", + "id": "preliminary-agency", + "metadata": {}, + "source": [ + "This notebook shows how to setup a new project, train a keypoint-MoSeq model and visualize the resulting syllables. \n", + "\n", + "**Total run time: ~90 min.**\n", + "\n", + "# Colab setup\n", + "\n", + "- Make a copy of this notebook if you plan to make changes and want them saved.\n", + "- Go to \"Runtime\">\"change runtime type\" and select \"Python 3\" and \"GPU\"" + ] + }, + { + "cell_type": "markdown", + "id": "aaab50c4", + "metadata": {}, + "source": [ + "### Install keypoint MoSeq" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84990f8c", + "metadata": {}, + "outputs": [], + "source": [ + "! pip install -U keypoint-moseq", + "\n", + "import os\n", + "from google.colab import drive\n", + "drive.mount('/content/drive')" + ] + }, + { + "cell_type": "markdown", + "id": "df94500c", + "metadata": {}, + "source": [ + "### Option 1: Use our example dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e944d0e1", + "metadata": {}, + "outputs": [], + "source": [ + "import gdown\n", + "url = 'https://drive.google.com/uc?id=1JGyS9MbdS3MtrlYnh4xdEQwe2bYoCuSZ'\n", + "output = 'dlc_example_project.zip'\n", + "gdown.download(url, output, quiet=False)\n", + "! unzip dlc_example_project.zip\n", + "\n", + "data_dir = \"dlc_example_project\"" + ] + }, + { + "cell_type": "markdown", + "id": "abad448e", + "metadata": {}, + "source": [ + "### Option 2: Use your own data\n", + "Upload your data to google drive and then change the following path as needed" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f7df6bb2", + "metadata": {}, + "outputs": [], + "source": [ + "# data_dir = \"/content/drive/MyDrive/MY_DATA_DIRECTORY\"" + ] + }, + { + "cell_type": "markdown", + "id": "f8a52043", + "metadata": {}, + "source": [ + "# Project setup\n", + "Create a new project directory with a keypoint-MoSeq `config.yml` file." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "intermediate-kenya", + "metadata": {}, + "outputs": [], + "source": [ + "import keypoint_moseq as kpms\n", + "\n", + "project_dir = '/content/drive/MyDrive/demo_project/'\n", + "config = lambda: kpms.load_config(project_dir)" + ] + }, + { + "cell_type": "markdown", + "id": "012d8287", + "metadata": {}, + "source": [ + "### Option 1: Setup from DeepLabCut" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4aa2fcd", + "metadata": { + "mystnb": { + "code_prompt_hide": "Setup from DeepLabCut", + "code_prompt_show": "Setup from DeepLabCut" + }, + "tags": [ + "hide-cell" + ] + }, + "outputs": [], + "source": [ + "dlc_config = os.path.join(data_dir, 'config.yaml')\n", + "kpms.setup_project(project_dir, deeplabcut_config=dlc_config)" + ] + }, + { + "cell_type": "markdown", + "id": "d8967d49", + "metadata": {}, + "source": [ + "### Option 2: Setup from SLEAP" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c13b902", + "metadata": { + "mystnb": { + "code_prompt_hide": "Setup from SLEAP", + "code_prompt_show": "Setup from SLEAP" + }, + "tags": [ + "hide-cell" + ] + }, + "outputs": [], + "source": [ + "# choose a .h5 file for one of your recordings\n", + "# sleap_file = os.path.join(data_dir, 'SLEAP_FILE_NAME') \n", + "# kpms.setup_project(project_dir, sleap_file=sleap_file)" + ] + }, + { + "cell_type": "markdown", + "id": "b9e62b8e", + "metadata": {}, + "source": [ + "### Options 3: Manual setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d804ac5", + "metadata": { + "mystnb": { + "code_prompt_hide": "Custom setup", + "code_prompt_show": "Custom setup" + }, + "tags": [ + "hide-cell" + ] + }, + "outputs": [], + "source": [ + "# bodyparts=[\n", + "# 'tail', 'spine4', 'spine3', 'spine2', 'spine1',\n", + "# 'head', 'nose', 'right ear', 'left ear']\n", + "\n", + "# skeleton=[\n", + "# ['tail', 'spine4'],\n", + "# ['spine4', 'spine3'],\n", + "# ['spine3', 'spine2'],\n", + "# ['spine2', 'spine1'],\n", + "# ['spine1', 'head'],\n", + "# ['nose', 'head'],\n", + "# ['left ear', 'head'],\n", + "# ['right ear', 'head']]\n", + "\n", + "# video_dir = os.path.join(data_dir, 'videos')\n", + "\n", + "# kpms.setup_project(\n", + "# project_dir,\n", + "# video_dir=video_dir,\n", + "# bodyparts=bodyparts,\n", + "# skeleton=skeleton)" + ] + }, + { + "cell_type": "markdown", + "id": "gothic-viking", + "metadata": {}, + "source": [ + "## Edit the config file\n", + "\n", + "The config can be edited in a text editor or using the function `kpms.update_config`, as shown below. 
In general, the following parameters should be specified for each project:\n", + "\n", + "- `bodyparts` (name of each keypoint; automatically imported from SLEAP/DeepLabCut)\n", + "- `use_bodyparts` (subset of bodyparts to use for modeling, set to all bodyparts by default; for mice we recommend excluding the tail)\n", + "- `anterior_bodyparts` and `posterior_bodyparts` (used for rotational alignment)\n", + "- `video_dir` (directory with videos of each experiment)\n", + "\n", + "Edit the config as follows for the [example DeepLabCut dataset](https://drive.google.com/drive/folders/1UNHQ_XCQEKLPPSjGspRopWBj6-YNDV6G?usp=share_link):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "theoretical-yahoo", + "metadata": {}, + "outputs": [], + "source": [ + "kpms.update_config(\n", + " project_dir,\n", + " video_dir=os.path.join(data_dir, 'videos'),\n", + " anterior_bodyparts=['nose'],\n", + " posterior_bodyparts=['spine4'],\n", + " use_bodyparts=[\n", + " 'spine4', 'spine3', 'spine2', 'spine1',\n", + " 'head', 'nose', 'right ear', 'left ear'])" + ] + }, + { + "cell_type": "markdown", + "id": "phantom-dating", + "metadata": {}, + "source": [ + "## Load data\n", + "\n", + "The code below shows how to load keypoint detections from DeepLabCut. To load other formats, replace `'deeplabcut'` in the example with one of `'sleap', 'anipose', 'sleap-anipose', 'nwb'`. For other formats, see the [FAQ](https://keypoint-moseq.readthedocs.io/en/latest/FAQs.html#loading-keypoint-tracking-data)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "expressed-christian", + "metadata": {}, + "outputs": [], + "source": [ + "# load data (e.g. from DeepLabCut)\n", + "keypoint_data_path = os.path.join(data_dir, 'videos') # can be a file, a directory, or a list of files\n", + "coordinates, confidences, bodyparts = kpms.load_keypoints(keypoint_data_path, 'deeplabcut')\n", + "\n", + "# format data for modeling\n", + "data, metadata = kpms.format_data(coordinates, confidences, **config())" + ] + }, + { + "cell_type": "markdown", + "id": "processed-struggle", + "metadata": {}, + "source": [ + "## Calibration [disabled in colab]\n", + "\n", + "The purpose of calibration is to learn the relationship between error and keypoint confidence scores. The resulting regression coefficients (`slope` and `intercept`) are used during modeling to set the noise prior on a per-frame, per-keypoint basis. **This step is disabled in colab**. In any case it can safely be skipped since the default parameters are fine for most datasets. " + ] + }, + { + "cell_type": "markdown", + "id": "organizational-theorem", + "metadata": {}, + "source": [ + "## Fit PCA\n", + "\n", + "Run the cell below to fit a PCA model to aligned and centered keypoint coordinates.\n", + "\n", + "- The model is saved to ``{project_dir}/pca.p`` and can be reloaded using ``kpms.load_pca``. \n", + "- Two plots are generated: a cumulative [scree plot](https://en.wikipedia.org/wiki/Scree_plot) and a depiction of each PC, where translucent nodes/edges represent the mean pose and opaque nodes/edges represent a perturbation in the direction of the PC. \n", + "- After fitting, edit `latent_dimension` in the config. This determines the dimension of the pose trajectory used to fit keypoint-MoSeq. A good heuristic is the number of dimensions needed to explain 90% of variance, or 10 dimensions - whichever is lower. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "respiratory-canvas", + "metadata": {}, + "outputs": [], + "source": [ + "pca = kpms.fit_pca(**data, **config())\n", + "kpms.save_pca(pca, project_dir)\n", + "\n", + "kpms.print_dims_to_explain_variance(pca, 0.9)\n", + "kpms.plot_scree(pca, project_dir=project_dir)\n", + "kpms.plot_pcs(pca, project_dir=project_dir, **config())\n", + "\n", + "# use the following to load an already fit model\n", + "# pca = kpms.load_pca(project_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a1a3a9b6", + "metadata": {}, + "outputs": [], + "source": [ + "kpms.update_config(project_dir, latent_dim=4)" + ] + }, + { + "cell_type": "markdown", + "id": "accomplished-pantyhose", + "metadata": {}, + "source": [ + "# Model fitting\n", + "\n", + "Fitting a keypoint-MoSeq model involves:\n", + "1. **Initialization:** Auto-regressive (AR) parameters and syllable sequences are randomly initialized using pose trajectories from PCA.\n", + "2. **Fitting an AR-HMM:** The AR parameters, transition probabilities and syllable sequences are iteratively updated through Gibbs sampling. \n", + "3. **Fitting the full model:** All parameters, including both the AR-HMM as well as centroid, heading, noise-estimates and continuous latent states (i.e. pose trajectories) are iteratively updated through Gibbs sampling. This step is especially useful for noisy data.\n", + "4. **Extracting model results:** The learned states of the model are parsed and saved to disk for vizualization and downstream analysis.\n", + "4. **[Optional] Applying the trained model:** The learned model parameters can be used to infer a syllable sequences for additional data.\n", + "\n", + "## Setting kappa\n", + "\n", + "Most users will need to adjust the **kappa** hyperparameter to achieve the desired distribution of syllable durations. For this tutorial we chose kappa values that yielded a median syllable duration of 400ms (12 frames). Most users will need to tune kappa to their particular dataset. Higher values of kappa lead to longer syllables. **You will need to pick two kappas: one for AR-HMM fitting and one for the full model.**\n", + "- We recommend iteratively updating kappa and refitting the model until the target syllable time-scale is attained. \n", + "- Model fitting can be stopped at any time by interrupting the kernel, and then restarted with a new kappa value.\n", + "- The full model will generally require a lower value of kappa to yield the same target syllable durations. \n", + "- To adjust the value of kappa in the model, use `kpms.update_hypparams` as shown below. Note that this command only changes kappa in the model dictionary, not the kappa value in the config file. The value in the config is only used during model initialization." 
+ ] + }, + { + "cell_type": "markdown", + "id": "utility-penetration", + "metadata": {}, + "source": [ + "## Initialization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "found-administrator", + "metadata": {}, + "outputs": [], + "source": [ + "# initialize the model\n", + "model = kpms.init_model(data, pca=pca, **config())\n", + "\n", + "# optionally modify kappa\n", + "# model = kpms.update_hypparams(model, kappa=NUMBER)" + ] + }, + { + "cell_type": "markdown", + "id": "partial-remove", + "metadata": {}, + "source": [ + "## Fitting an AR-HMM\n", + "\n", + "In addition to fitting an AR-HMM, the function below:\n", + "- generates a name for the model and a corresponding directory in `project_dir`\n", + "- saves a checkpoint every 25 iterations from which fitting can be restarted\n", + "- plots the progress of fitting every 25 iterations, including\n", + " - the distributions of syllable frequencies and durations for the most recent iteration\n", + " - the change in median syllable duration across fitting iterations\n", + " - a sample of the syllable sequence across iterations in a random window" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "888e6ff7", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "num_ar_iters = 50\n", + "\n", + "model, model_name = kpms.fit_model(\n", + " model, data, metadata, project_dir,\n", + " ar_only=True, num_iters=num_ar_iters)" + ] + }, + { + "cell_type": "markdown", + "id": "thorough-identity", + "metadata": {}, + "source": [ + "## Fitting the full model\n", + "\n", + "The following code fits a full keypoint-MoSeq model using the results of AR-HMM fitting for initialization. If using your own data, you may need to try a few values of kappa at this step. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "swiss-repeat", + "metadata": {}, + "outputs": [], + "source": [ + "# load model checkpoint\n", + "model, data, metadata, current_iter = kpms.load_checkpoint(\n", + " project_dir, model_name, iteration=num_ar_iters)\n", + "\n", + "# modify kappa to maintain the desired syllable time-scale\n", + "model = kpms.update_hypparams(model, kappa=1e4)\n", + "\n", + "# run fitting for an additional 500 iters\n", + "model = kpms.fit_model(\n", + " model, data, metadata, project_dir, model_name, ar_only=False, \n", + " start_iter=current_iter, num_iters=current_iter+500)[0]\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "0837e0ad", + "metadata": {}, + "source": [ + "## Sort syllables by frequency\n", + "\n", + "Permute the states and parameters of a saved checkpoint so that syllables are labeled in order of frequency (i.e. so that `0` is the most frequent, `1` is the second most, and so on). " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "902ccabf", + "metadata": {}, + "outputs": [], + "source": [ + "# modify a saved checkpoint so syllables are ordered by frequency\n", + "kpms.reindex_syllables_in_checkpoint(project_dir, model_name);" + ] + }, + { + "cell_type": "markdown", + "id": "bc027d4a", + "metadata": {}, + "source": [ + "```{warning}\n", + "Reindexing is only applied to the checkpoint file. 
Therefore, if you perform this step after extracting the modeling results or generating vizualizations, then those steps must be repeated.\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "79951b99", + "metadata": {}, + "source": [ + "## Extract model results\n", + "\n", + "Parse the modeling results and save them to `{project_dir}/{model_name}/results.h5`. The results are stored as follows, and can be reloaded at a later time using `kpms.load_results`. Check the docs for an [in-depth explanation of the modeling results](https://keypoint-moseq.readthedocs.io/en/latest/FAQs.html#interpreting-model-outputs).\n", + "```\n", + " results.h5\n", + " ├──recording_name1\n", + " │ ├──syllable # syllable labels (z)\n", + " │ ├──latent_state # inferred low-dim pose state (x)\n", + " │ ├──centroid # inferred centroid (v)\n", + " │ └──heading # inferred heading (h)\n", + " ⋮\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8abffb5", + "metadata": {}, + "outputs": [], + "source": [ + "# load the most recent model checkpoint\n", + "model, data, metadata, current_iter = kpms.load_checkpoint(project_dir, model_name)\n", + "\n", + "# extract results\n", + "results = kpms.extract_results(model, metadata, project_dir, model_name)" + ] + }, + { + "cell_type": "markdown", + "id": "a37f9d42", + "metadata": {}, + "source": [ + "### [Optional] Save results to csv\n", + "\n", + "After extracting to an h5 file, the results can also be saved as csv files. A separate file will be created for each recording and saved to `{project_dir}/{model_name}/results/`. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ceea1e2", + "metadata": {}, + "outputs": [], + "source": [ + "# optionally save results as csv\n", + "kpms.save_results_as_csv(results, project_dir, model_name)" + ] + }, + { + "cell_type": "markdown", + "id": "empty-houston", + "metadata": {}, + "source": [ + "## Apply to new data\n", + "\n", + "The code below shows how to apply a trained model to new data. This is useful if you have performed new experiments and would like to maintain an existing set of syllables. The results for the new experiments will be added to the existing `results.h5` file. **This step is optional and can be skipped if you do not have new data to add**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "abae55fd", + "metadata": {}, + "outputs": [], + "source": [ + "# load the most recent model checkpoint and pca object\n", + "# model = kpms.load_checkpoint(project_dir, model_name)[0]\n", + "\n", + "# # load new data (e.g. from deeplabcut)\n", + "# new_data = 'path/to/new/data/' # can be a file, a directory, or a list of files\n", + "# coordinates, confidences, bodyparts = kpms.load_keypoints(new_data, 'deeplabcut')\n", + "# data, metadata = kpms.format_data(coordinates, confidences, **config())\n", + "\n", + "# # apply saved model to new data\n", + "# results = kpms.apply_model(model, data, metadata, project_dir, model_name, **config())\n", + "\n", + "# optionally rerun `save_results_as_csv` to export the new results\n", + "# kpms.save_results_as_csv(results, project_dir, model_name)" + ] + }, + { + "cell_type": "markdown", + "id": "breeding-fashion", + "metadata": {}, + "source": [ + "# Visualization" + ] + }, + { + "cell_type": "markdown", + "id": "a2491a0d", + "metadata": {}, + "source": [ + "## Trajectory plots\n", + "Generate plots showing the median trajectory of poses associated with each given syllable. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "subject-disney", + "metadata": {}, + "outputs": [], + "source": [ + "results = kpms.load_results(project_dir, model_name)\n", + "kpms.generate_trajectory_plots(coordinates, results, project_dir, model_name, **config())" + ] + }, + { + "cell_type": "markdown", + "id": "617a66ed", + "metadata": {}, + "source": [ + "## Grid movies\n", + "Generate video clips showing examples of each syllable. \n", + "\n", + "*Note: the code below will only work with 2D data. For 3D data, see the [FAQ](https://keypoint-moseq.readthedocs.io/en/latest/FAQs.html#making-grid-movies-for-3d-data).*" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dominant-packet", + "metadata": {}, + "outputs": [], + "source": [ + "kpms.generate_grid_movies(results, project_dir, model_name, coordinates=coordinates, **config());" + ] + }, + { + "cell_type": "markdown", + "id": "d670667d", + "metadata": {}, + "source": [ + "## Syllable Dendrogram\n", + "Plot a dendrogram representing distances between each syllable's median trajectory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81a324c4", + "metadata": {}, + "outputs": [], + "source": [ + "kpms.plot_similarity_dendrogram(coordinates, results, project_dir, model_name, **config())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "keypoint_moseq", + "language": "python", + "name": "keypoint_moseq" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false + } }, - "tags": [ - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "dlc_config = os.path.join(data_dir, 'config.yaml')\n", - "kpms.setup_project(project_dir, deeplabcut_config=dlc_config)" - ] - }, - { - "cell_type": "markdown", - "id": "d8967d49", - "metadata": {}, - "source": [ - "### Option 2: Setup from SLEAP" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0c13b902", - "metadata": { - "mystnb": { - "code_prompt_hide": "Setup from SLEAP", - "code_prompt_show": "Setup from SLEAP" - }, - "tags": [ - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "# choose a .h5 file for one of your recordings\n", - "# sleap_file = os.path.join(data_dir, 'SLEAP_FILE_NAME') \n", - "# kpms.setup_project(project_dir, sleap_file=sleap_file)" - ] - }, - { - "cell_type": "markdown", - "id": "b9e62b8e", - "metadata": {}, - "source": [ - "### Options 3: Manual setup" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0d804ac5", - "metadata": { - "mystnb": { - "code_prompt_hide": "Custom setup", - "code_prompt_show": "Custom setup" - }, - "tags": [ - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "# bodyparts=[\n", - "# 'tail', 'spine4', 'spine3', 'spine2', 'spine1',\n", - "# 'head', 'nose', 'right ear', 'left ear']\n", - 
"\n", - "# skeleton=[\n", - "# ['tail', 'spine4'],\n", - "# ['spine4', 'spine3'],\n", - "# ['spine3', 'spine2'],\n", - "# ['spine2', 'spine1'],\n", - "# ['spine1', 'head'],\n", - "# ['nose', 'head'],\n", - "# ['left ear', 'head'],\n", - "# ['right ear', 'head']]\n", - "\n", - "# video_dir = os.path.join(data_dir, 'videos')\n", - "\n", - "# kpms.setup_project(\n", - "# project_dir,\n", - "# video_dir=video_dir,\n", - "# bodyparts=bodyparts,\n", - "# skeleton=skeleton)" - ] - }, - { - "cell_type": "markdown", - "id": "gothic-viking", - "metadata": {}, - "source": [ - "## Edit the config file\n", - "\n", - "The config can be edited in a text editor or using the function `kpms.update_config`, as shown below. In general, the following parameters should be specified for each project:\n", - "\n", - "- `bodyparts` (name of each keypoint; automatically imported from SLEAP/DeepLabCut)\n", - "- `use_bodyparts` (subset of bodyparts to use for modeling, set to all bodyparts by default; for mice we recommend excluding the tail)\n", - "- `anterior_bodyparts` and `posterior_bodyparts` (used for rotational alignment)\n", - "- `video_dir` (directory with videos of each experiment)\n", - "\n", - "Edit the config as follows for the [example DeepLabCut dataset](https://drive.google.com/drive/folders/1UNHQ_XCQEKLPPSjGspRopWBj6-YNDV6G?usp=share_link):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "theoretical-yahoo", - "metadata": {}, - "outputs": [], - "source": [ - "kpms.update_config(\n", - " project_dir,\n", - " video_dir=os.path.join(data_dir, 'videos'),\n", - " anterior_bodyparts=['nose'],\n", - " posterior_bodyparts=['spine4'],\n", - " use_bodyparts=[\n", - " 'spine4', 'spine3', 'spine2', 'spine1',\n", - " 'head', 'nose', 'right ear', 'left ear'])" - ] - }, - { - "cell_type": "markdown", - "id": "phantom-dating", - "metadata": {}, - "source": [ - "## Load data\n", - "\n", - "The code below shows how to load keypoint detections from DeepLabCut. To load other formats, replace `'deeplabcut'` in the example with one of `'sleap', 'anipose', 'sleap-anipose', 'nwb'`. For other formats, see the [FAQ](https://keypoint-moseq.readthedocs.io/en/latest/FAQs.html#loading-keypoint-tracking-data)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "expressed-christian", - "metadata": {}, - "outputs": [], - "source": [ - "# load data (e.g. from DeepLabCut)\n", - "keypoint_data_path = os.path.join(data_dir, 'videos') # can be a file, a directory, or a list of files\n", - "coordinates, confidences, bodyparts = kpms.load_keypoints(keypoint_data_path, 'deeplabcut')\n", - "\n", - "# format data for modeling\n", - "data, metadata = kpms.format_data(coordinates, confidences, **config())" - ] - }, - { - "cell_type": "markdown", - "id": "processed-struggle", - "metadata": {}, - "source": [ - "## Calibration [disabled in colab]\n", - "\n", - "The purpose of calibration is to learn the relationship between error and keypoint confidence scores. The resulting regression coefficients (`slope` and `intercept`) are used during modeling to set the noise prior on a per-frame, per-keypoint basis. **This step is disabled in colab**. In any case it can safely be skipped since the default parameters are fine for most datasets. 
" - ] - }, - { - "cell_type": "markdown", - "id": "organizational-theorem", - "metadata": {}, - "source": [ - "## Fit PCA\n", - "\n", - "Run the cell below to fit a PCA model to aligned and centered keypoint coordinates.\n", - "\n", - "- The model is saved to ``{project_dir}/pca.p`` and can be reloaded using ``kpms.load_pca``. \n", - "- Two plots are generated: a cumulative [scree plot](https://en.wikipedia.org/wiki/Scree_plot) and a depiction of each PC, where translucent nodes/edges represent the mean pose and opaque nodes/edges represent a perturbation in the direction of the PC. \n", - "- After fitting, edit `latent_dimension` in the config. This determines the dimension of the pose trajectory used to fit keypoint-MoSeq. A good heuristic is the number of dimensions needed to explain 90% of variance, or 10 dimensions - whichever is lower. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "respiratory-canvas", - "metadata": {}, - "outputs": [], - "source": [ - "pca = kpms.fit_pca(**data, **config())\n", - "kpms.save_pca(pca, project_dir)\n", - "\n", - "kpms.print_dims_to_explain_variance(pca, 0.9)\n", - "kpms.plot_scree(pca, project_dir=project_dir)\n", - "kpms.plot_pcs(pca, project_dir=project_dir, **config())\n", - "\n", - "# use the following to load an already fit model\n", - "# pca = kpms.load_pca(project_dir)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a1a3a9b6", - "metadata": {}, - "outputs": [], - "source": [ - "kpms.update_config(project_dir, latent_dim=4)" - ] - }, - { - "cell_type": "markdown", - "id": "accomplished-pantyhose", - "metadata": {}, - "source": [ - "# Model fitting\n", - "\n", - "Fitting a keypoint-MoSeq model involves:\n", - "1. **Initialization:** Auto-regressive (AR) parameters and syllable sequences are randomly initialized using pose trajectories from PCA.\n", - "2. **Fitting an AR-HMM:** The AR parameters, transition probabilities and syllable sequences are iteratively updated through Gibbs sampling. \n", - "3. **Fitting the full model:** All parameters, including both the AR-HMM as well as centroid, heading, noise-estimates and continuous latent states (i.e. pose trajectories) are iteratively updated through Gibbs sampling. This step is especially useful for noisy data.\n", - "4. **Extracting model results:** The learned states of the model are parsed and saved to disk for vizualization and downstream analysis.\n", - "4. **[Optional] Applying the trained model:** The learned model parameters can be used to infer a syllable sequences for additional data.\n", - "\n", - "## Setting kappa\n", - "\n", - "Most users will need to adjust the **kappa** hyperparameter to achieve the desired distribution of syllable durations. For this tutorial we chose kappa values that yielded a median syllable duration of 400ms (12 frames). Most users will need to tune kappa to their particular dataset. Higher values of kappa lead to longer syllables. **You will need to pick two kappas: one for AR-HMM fitting and one for the full model.**\n", - "- We recommend iteratively updating kappa and refitting the model until the target syllable time-scale is attained. \n", - "- Model fitting can be stopped at any time by interrupting the kernel, and then restarted with a new kappa value.\n", - "- The full model will generally require a lower value of kappa to yield the same target syllable durations. \n", - "- To adjust the value of kappa in the model, use `kpms.update_hypparams` as shown below. 
Note that this command only changes kappa in the model dictionary, not the kappa value in the config file. The value in the config is only used during model initialization." - ] - }, - { - "cell_type": "markdown", - "id": "utility-penetration", - "metadata": {}, - "source": [ - "## Initialization" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "found-administrator", - "metadata": {}, - "outputs": [], - "source": [ - "# initialize the model\n", - "model = kpms.init_model(data, pca=pca, **config())\n", - "\n", - "# optionally modify kappa\n", - "# model = kpms.update_hypparams(model, kappa=NUMBER)" - ] - }, - { - "cell_type": "markdown", - "id": "partial-remove", - "metadata": {}, - "source": [ - "## Fitting an AR-HMM\n", - "\n", - "In addition to fitting an AR-HMM, the function below:\n", - "- generates a name for the model and a corresponding directory in `project_dir`\n", - "- saves a checkpoint every 25 iterations from which fitting can be restarted\n", - "- plots the progress of fitting every 25 iterations, including\n", - " - the distributions of syllable frequencies and durations for the most recent iteration\n", - " - the change in median syllable duration across fitting iterations\n", - " - a sample of the syllable sequence across iterations in a random window" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "888e6ff7", - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "num_ar_iters = 50\n", - "\n", - "model, model_name = kpms.fit_model(\n", - " model, data, metadata, project_dir,\n", - " ar_only=True, num_iters=num_ar_iters)" - ] - }, - { - "cell_type": "markdown", - "id": "thorough-identity", - "metadata": {}, - "source": [ - "## Fitting the full model\n", - "\n", - "The following code fits a full keypoint-MoSeq model using the results of AR-HMM fitting for initialization. If using your own data, you may need to try a few values of kappa at this step. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "swiss-repeat", - "metadata": {}, - "outputs": [], - "source": [ - "# load model checkpoint\n", - "model, data, metadata, current_iter = kpms.load_checkpoint(\n", - " project_dir, model_name, iteration=num_ar_iters)\n", - "\n", - "# modify kappa to maintain the desired syllable time-scale\n", - "model = kpms.update_hypparams(model, kappa=1e4)\n", - "\n", - "# run fitting for an additional 500 iters\n", - "model = kpms.fit_model(\n", - " model, data, metadata, project_dir, model_name, ar_only=False, \n", - " start_iter=current_iter, num_iters=current_iter+500)[0]\n", - " " - ] - }, - { - "cell_type": "markdown", - "id": "0837e0ad", - "metadata": {}, - "source": [ - "## Sort syllables by frequency\n", - "\n", - "Permute the states and parameters of a saved checkpoint so that syllables are labeled in order of frequency (i.e. so that `0` is the most frequent, `1` is the second most, and so on). " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "902ccabf", - "metadata": {}, - "outputs": [], - "source": [ - "# modify a saved checkpoint so syllables are ordered by frequency\n", - "kpms.reindex_syllables_in_checkpoint(project_dir, model_name);" - ] - }, - { - "cell_type": "markdown", - "id": "bc027d4a", - "metadata": {}, - "source": [ - "```{warning}\n", - "Reindexing is only applied to the checkpoint file. 
Therefore, if you perform this step after extracting the modeling results or generating vizualizations, then those steps must be repeated.\n", - "```" - ] - }, - { - "cell_type": "markdown", - "id": "79951b99", - "metadata": {}, - "source": [ - "## Extract model results\n", - "\n", - "Parse the modeling results and save them to `{project_dir}/{model_name}/results.h5`. The results are stored as follows, and can be reloaded at a later time using `kpms.load_results`. Check the docs for an [in-depth explanation of the modeling results](https://keypoint-moseq.readthedocs.io/en/latest/FAQs.html#interpreting-model-outputs).\n", - "```\n", - " results.h5\n", - " ├──recording_name1\n", - " │ ├──syllable # syllable labels (z)\n", - " │ ├──latent_state # inferred low-dim pose state (x)\n", - " │ ├──centroid # inferred centroid (v)\n", - " │ └──heading # inferred heading (h)\n", - " ⋮\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d8abffb5", - "metadata": {}, - "outputs": [], - "source": [ - "# load the most recent model checkpoint\n", - "model, data, metadata, current_iter = kpms.load_checkpoint(project_dir, model_name)\n", - "\n", - "# extract results\n", - "results = kpms.extract_results(model, metadata, project_dir, model_name)" - ] - }, - { - "cell_type": "markdown", - "id": "a37f9d42", - "metadata": {}, - "source": [ - "### [Optional] Save results to csv\n", - "\n", - "After extracting to an h5 file, the results can also be saved as csv files. A separate file will be created for each recording and saved to `{project_dir}/{model_name}/results/`. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0ceea1e2", - "metadata": {}, - "outputs": [], - "source": [ - "# optionally save results as csv\n", - "kpms.save_results_as_csv(results, project_dir, model_name)" - ] - }, - { - "cell_type": "markdown", - "id": "empty-houston", - "metadata": {}, - "source": [ - "## Apply to new data\n", - "\n", - "The code below shows how to apply a trained model to new data. This is useful if you have performed new experiments and would like to maintain an existing set of syllables. The results for the new experiments will be added to the existing `results.h5` file. **This step is optional and can be skipped if you do not have new data to add**." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "abae55fd", - "metadata": {}, - "outputs": [], - "source": [ - "# load the most recent model checkpoint and pca object\n", - "# model = kpms.load_checkpoint(project_dir, model_name)[0]\n", - "\n", - "# # load new data (e.g. from deeplabcut)\n", - "# new_data = 'path/to/new/data/' # can be a file, a directory, or a list of files\n", - "# coordinates, confidences, bodyparts = kpms.load_keypoints(new_data, 'deeplabcut')\n", - "# data, metadata = kpms.format_data(coordinates, confidences, **config())\n", - "\n", - "# # apply saved model to new data\n", - "# results = kpms.apply_model(model, data, metadata, project_dir, model_name, **config())\n", - "\n", - "# optionally rerun `save_results_as_csv` to export the new results\n", - "# kpms.save_results_as_csv(results, project_dir, model_name)" - ] - }, - { - "cell_type": "markdown", - "id": "breeding-fashion", - "metadata": {}, - "source": [ - "# Visualization" - ] - }, - { - "cell_type": "markdown", - "id": "a2491a0d", - "metadata": {}, - "source": [ - "## Trajectory plots\n", - "Generate plots showing the median trajectory of poses associated with each given syllable. 
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "subject-disney", - "metadata": {}, - "outputs": [], - "source": [ - "results = kpms.load_results(project_dir, model_name)\n", - "kpms.generate_trajectory_plots(coordinates, results, project_dir, model_name, **config())" - ] - }, - { - "cell_type": "markdown", - "id": "617a66ed", - "metadata": {}, - "source": [ - "## Grid movies\n", - "Generate video clips showing examples of each syllable. \n", - "\n", - "*Note: the code below will only work with 2D data. For 3D data, see the [FAQ](https://keypoint-moseq.readthedocs.io/en/latest/FAQs.html#making-grid-movies-for-3d-data).*" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dominant-packet", - "metadata": {}, - "outputs": [], - "source": [ - "kpms.generate_grid_movies(results, project_dir, model_name, coordinates=coordinates, **config());" - ] - }, - { - "cell_type": "markdown", - "id": "d670667d", - "metadata": {}, - "source": [ - "## Syllable Dendrogram\n", - "Plot a dendrogram representing distances between each syllable's median trajectory." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "81a324c4", - "metadata": {}, - "outputs": [], - "source": [ - "kpms.plot_similarity_dendrogram(coordinates, results, project_dir, model_name, **config())" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "keypoint_moseq", - "language": "python", - "name": "keypoint_moseq" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.16" - }, - "varInspector": { - "cols": { - "lenName": 16, - "lenType": 16, - "lenVar": 40 - }, - "kernels_config": { - "python": { - "delete_cmd_postfix": "", - "delete_cmd_prefix": "del ", - "library": "var_list.py", - "varRefreshCmd": "print(var_dic_list())" - }, - "r": { - "delete_cmd_postfix": ") ", - "delete_cmd_prefix": "rm(", - "library": "var_list.r", - "varRefreshCmd": "cat(var_dic_list()) " - } - }, - "types_to_exclude": [ - "module", - "function", - "builtin_function_or_method", - "instance", - "_Feature" - ], - "window_display": false - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/docs/source/install.rst b/docs/source/install.rst index 1f50bf0..0334fa6 100644 --- a/docs/source/install.rst +++ b/docs/source/install.rst @@ -1,5 +1,5 @@ Local installation ------------------- +================== - Total installation time is around 10 minutes. - The first import of keypoint_moseq after installation can take a few minutes. @@ -11,9 +11,7 @@ Local installation Install using conda -~~~~~~~~~~~~~~~~~~~ - - +------------------- Use conda environment files to automatically install the appropriate GPU drivers and other dependencies. Start by cloning the repository:: @@ -53,7 +51,7 @@ To run keypoint-moseq in jupyter, either launch jupyterlab directly from the `ke Install using pip -~~~~~~~~~~~~~~~~~ +----------------- .. note:: @@ -64,19 +62,51 @@ Create a new conda environment with python 3.9:: conda create -n keypoint_moseq python=3.9 conda activate keypoint_moseq -Install jax using one of the lines below:: +Next install jax and jax-moseq using one of the options below. - # MacOS or Linux (CPU) - pip install "jax[cpu]==0.3.22" -f https://storage.googleapis.com/jax-releases/jax_releases.html +1. 
**CPU only** - # MacOS or Linux (GPU with CUDA 11.X) - pip install "jax[cuda11_cudnn82]==0.3.22" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + Install jax for your operating-system:: - # Windows (CPU) - pip install jax==0.3.22 https://whls.blob.core.windows.net/unstable/cpu/jaxlib-0.3.22-cp39-cp39-win_amd64.whl + # MacOS or Linux (CPU) + pip install "jax[cpu]==0.3.22" -f https://storage.googleapis.com/jax-releases/jax_releases.html + + # Windows (CPU) + pip install jax==0.3.22 https://whls.blob.core.windows.net/unstable/cpu/jaxlib-0.3.22-cp39-cp39-win_amd64.whl + + Install jax-moseq:: + + pip install jax-moseq + + +2. **GPU with CUDA 11** + + Install jax for your operating-system:: + + # MacOS or Linux (GPU with CUDA 11) + pip install "jax[cuda11_cudnn82]==0.3.22" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + + + # Windows (GPU with CUDA 11) + pip install jax==0.3.22 https://whls.blob.core.windows.net/unstable/cuda111/jaxlib-0.3.22+cuda11.cudnn82-cp39-cp39-win_amd64.whl + + Install jax-moseq:: + + pip install jax-moseq[cuda11] + + +3. **GPU with CUDA 12** + + This option assumes that you already have a working installation of jax that is compatible with CUDA 12. Among other things, the following code should run without error:: + + import jax + jax.random.PRNGKey(0) + + + Install jax-moseq:: + + pip install jax-moseq[cuda12] - # Windows (GPU) - pip install jax==0.3.22 https://whls.blob.core.windows.net/unstable/cuda111/jaxlib-0.3.22+cuda11.cudnn82-cp39-cp39-win_amd64.whl Install `keypoint-moseq `_:: diff --git a/setup.cfg b/setup.cfg index a03eb0e..b234acf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,7 +14,7 @@ python_requires = >=3.9 install_requires = seaborn==0.13.0 cytoolz - matplotlib + matplotlib==3.8.4 tqdm ipykernel imageio[ffmpeg] @@ -25,6 +25,9 @@ install_requires = bokeh pandas tables + bokeh==2.4.3 + panel==0.14.4 + holoviews==1.15.4 networkx sleap_io pynwb @@ -33,10 +36,15 @@ install_requires = ipython_genutils tabulate commentjson - jaxtyping==0.2.14 - etils==1.5.2 - scipy==1.11.3 - jax-moseq==0.2.2 + jax-moseq>=0.2.2 + numpy<=1.26.4 + +[options.extras_require] +dev = + sphinx + sphinx-rtd-theme + autodocsumm + myst-nb [options.package_data] * = *.md
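As a quick sanity check for any of the GPU install options above (a minimal sketch, not taken from the diff; it assumes jax 0.3.22 and jax-moseq were installed following install.rst or the updated conda environment files), the following should report a GPU backend::

    import jax
    import jax.numpy as jnp

    # List the devices jax can see; with a working CUDA install this should
    # include at least one GPU device, otherwise only the CPU.
    print(jax.devices())
    print(jax.default_backend())  # expected 'gpu' for the CUDA options, 'cpu' otherwise

    # Run a small computation to confirm the selected backend actually executes.
    x = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))
    print(jnp.dot(x, x).sum())

If this prints only CPU devices after installing one of the CUDA options, the jaxlib wheel and the local CUDA/cuDNN versions are the first things to check.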