diff --git a/noxfile.py b/noxfile.py index a46d9cc..ad4811b 100644 --- a/noxfile.py +++ b/noxfile.py @@ -15,6 +15,11 @@ def lint(session): @nox.session(python=["3.8", "3.9", "3.10", "3.11", "3.12"]) def tests(session): + session.install( + 'torch==2.2.1', + 'torchvision', + '--index-url', 'https://download.pytorch.org/whl/cpu' + ) session.install('.') session.install('pytest') session.install('pytest-mock')