From a614a500a5b0f8986c9b1da737da4923518ed4eb Mon Sep 17 00:00:00 2001 From: Caleb Weinreb Date: Fri, 19 Jul 2024 13:17:54 -0400 Subject: [PATCH 1/8] bump jax-moseq dep and remove pinned deps related to modeling --- setup.cfg | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/setup.cfg b/setup.cfg index a03eb0e..258aafc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,7 +14,7 @@ python_requires = >=3.9 install_requires = seaborn==0.13.0 cytoolz - matplotlib + matplotlib==3.8.4 tqdm ipykernel imageio[ffmpeg] @@ -33,10 +33,7 @@ install_requires = ipython_genutils tabulate commentjson - jaxtyping==0.2.14 - etils==1.5.2 - scipy==1.11.3 - jax-moseq==0.2.2 + jax-moseq==0.2.3 [options.package_data] * = *.md From b96cfc2f0d2ed58207a1a33545627967277d2a4d Mon Sep 17 00:00:00 2001 From: Caleb Weinreb Date: Fri, 19 Jul 2024 14:29:39 -0400 Subject: [PATCH 2/8] instal cuda11 version of jax-moseq when using conda env files --- conda_envs/environment.linux_gpu.yml | 1 + conda_envs/environment.win64_gpu.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/conda_envs/environment.linux_gpu.yml b/conda_envs/environment.linux_gpu.yml index dca1766..544f77c 100644 --- a/conda_envs/environment.linux_gpu.yml +++ b/conda_envs/environment.linux_gpu.yml @@ -14,5 +14,6 @@ dependencies: - pip - pip: - "keypoint-moseq" + - "jax-moseq[cuda11]" - jupyterlab - etils==1.5.2 \ No newline at end of file diff --git a/conda_envs/environment.win64_gpu.yml b/conda_envs/environment.win64_gpu.yml index e9ff628..5eca66b 100644 --- a/conda_envs/environment.win64_gpu.yml +++ b/conda_envs/environment.win64_gpu.yml @@ -16,4 +16,5 @@ dependencies: - jax==0.3.22 - "https://whls.blob.core.windows.net/unstable/cuda111/jaxlib-0.3.22+cuda11.cudnn82-cp39-cp39-win_amd64.whl" - "keypoint-moseq" + - "jax-moseq[cuda11]" - jupyterlab From 1e28437bd3c2f3e1206ee8c36b2c3c27993c068f Mon Sep 17 00:00:00 2001 From: Caleb Weinreb Date: Fri, 19 Jul 2024 16:05:24 -0400 Subject: [PATCH 3/8] unpin jax-moseq version --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 258aafc..c6f98ec 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,7 +33,7 @@ install_requires = ipython_genutils tabulate commentjson - jax-moseq==0.2.3 + jax-moseq>=0.2.2 [options.package_data] * = *.md From 8b86dd3089344dbfc2fd96cdf765682d57764d3e Mon Sep 17 00:00:00 2001 From: Caleb Weinreb Date: Wed, 18 Sep 2024 17:00:22 -0400 Subject: [PATCH 4/8] fix panel import bug --- setup.cfg | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.cfg b/setup.cfg index c6f98ec..4b8c25d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,6 +25,9 @@ install_requires = bokeh pandas tables + bokeh==2.4.3 + panel==0.14.4 + holoviews==1.15.4 networkx sleap_io pynwb From ab1379862329c5e8ebc5b6467999ca007f97363f Mon Sep 17 00:00:00 2001 From: Caleb Weinreb Date: Fri, 20 Sep 2024 14:26:17 -0400 Subject: [PATCH 5/8] add manual jax-moseq install to docs and update colab --- docs/keypoint_moseq_colab.ipynb | 1 + docs/source/install.rst | 58 +++++++++++++++++++++++++-------- 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/docs/keypoint_moseq_colab.ipynb b/docs/keypoint_moseq_colab.ipynb index 8ce7771..34791ad 100644 --- a/docs/keypoint_moseq_colab.ipynb +++ b/docs/keypoint_moseq_colab.ipynb @@ -33,6 +33,7 @@ "! apt update && apt install cuda-11-8\n", "! pip install tensorflow==2.12.0\n", "! pip install --upgrade \"jax[cuda]==0.3.22\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", + "! pip install jax-moseq[cuda11]\n", "! pip install keypoint-moseq\n", "\n", "import os\n", diff --git a/docs/source/install.rst b/docs/source/install.rst index 1f50bf0..0334fa6 100644 --- a/docs/source/install.rst +++ b/docs/source/install.rst @@ -1,5 +1,5 @@ Local installation ------------------- +================== - Total installation time is around 10 minutes. - The first import of keypoint_moseq after installation can take a few minutes. @@ -11,9 +11,7 @@ Local installation Install using conda -~~~~~~~~~~~~~~~~~~~ - - +------------------- Use conda environment files to automatically install the appropriate GPU drivers and other dependencies. Start by cloning the repository:: @@ -53,7 +51,7 @@ To run keypoint-moseq in jupyter, either launch jupyterlab directly from the `ke Install using pip -~~~~~~~~~~~~~~~~~ +----------------- .. note:: @@ -64,19 +62,51 @@ Create a new conda environment with python 3.9:: conda create -n keypoint_moseq python=3.9 conda activate keypoint_moseq -Install jax using one of the lines below:: +Next install jax and jax-moseq using one of the options below. - # MacOS or Linux (CPU) - pip install "jax[cpu]==0.3.22" -f https://storage.googleapis.com/jax-releases/jax_releases.html +1. **CPU only** - # MacOS or Linux (GPU with CUDA 11.X) - pip install "jax[cuda11_cudnn82]==0.3.22" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + Install jax for your operating-system:: - # Windows (CPU) - pip install jax==0.3.22 https://whls.blob.core.windows.net/unstable/cpu/jaxlib-0.3.22-cp39-cp39-win_amd64.whl + # MacOS or Linux (CPU) + pip install "jax[cpu]==0.3.22" -f https://storage.googleapis.com/jax-releases/jax_releases.html + + # Windows (CPU) + pip install jax==0.3.22 https://whls.blob.core.windows.net/unstable/cpu/jaxlib-0.3.22-cp39-cp39-win_amd64.whl + + Install jax-moseq:: + + pip install jax-moseq + + +2. **GPU with CUDA 11** + + Install jax for your operating-system:: + + # MacOS or Linux (GPU with CUDA 11) + pip install "jax[cuda11_cudnn82]==0.3.22" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + + + # Windows (GPU with CUDA 11) + pip install jax==0.3.22 https://whls.blob.core.windows.net/unstable/cuda111/jaxlib-0.3.22+cuda11.cudnn82-cp39-cp39-win_amd64.whl + + Install jax-moseq:: + + pip install jax-moseq[cuda11] + + +3. **GPU with CUDA 12** + + This option assumes that you already have a working installation of jax that is compatible with CUDA 12. Among other things, the following code should run without error:: + + import jax + jax.random.PRNGKey(0) + + + Install jax-moseq:: + + pip install jax-moseq[cuda12] - # Windows (GPU) - pip install jax==0.3.22 https://whls.blob.core.windows.net/unstable/cuda111/jaxlib-0.3.22+cuda11.cudnn82-cp39-cp39-win_amd64.whl Install `keypoint-moseq `_:: From db464e54c5627f8d91bb9fd0fcd901b3dec0b3f5 Mon Sep 17 00:00:00 2001 From: Caleb Weinreb Date: Wed, 25 Sep 2024 14:54:30 -0400 Subject: [PATCH 6/8] add dev install option --- setup.cfg | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/setup.cfg b/setup.cfg index 4b8c25d..74e1963 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,6 +38,13 @@ install_requires = commentjson jax-moseq>=0.2.2 +[options.extras_require] +dev = + sphinx + sphinx-rtd-theme + autodocsumm + myst-nb + [options.package_data] * = *.md From 84bb86e574a63f2d2e9e55fb55eb2a72f82bf3f3 Mon Sep 17 00:00:00 2001 From: Caleb Weinreb Date: Wed, 25 Sep 2024 14:57:47 -0400 Subject: [PATCH 7/8] fix docs typo --- docs/source/FAQs.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/FAQs.rst b/docs/source/FAQs.rst index 2067f2a..4e0afff 100644 --- a/docs/source/FAQs.rst +++ b/docs/source/FAQs.rst @@ -195,7 +195,7 @@ The final output of keypoint MoSeq is a results .h5 file (and optionally a direc Validating results when applying a model to new data ---------------------------------------------------- +---------------------------------------------------- When applying a model to new data, it may be useful to generate new grid movies and trajectory plots so you can confirm that the meaning of the syllables has been preserved. Let's say you've already applied the model to new data as follows: .. code-block:: python From e91ef9e85631591da5586134770d017d1ff5ec9d Mon Sep 17 00:00:00 2001 From: Caleb Weinreb Date: Fri, 18 Oct 2024 15:52:07 -0400 Subject: [PATCH 8/8] prevent numpy 2 --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 74e1963..b234acf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,6 +37,7 @@ install_requires = tabulate commentjson jax-moseq>=0.2.2 + numpy<=1.26.4 [options.extras_require] dev =