
Commit

skeleton for optional dependencies implemented for jax - see #9
tvercaut committed Nov 22, 2022
1 parent 8521f36 commit 5f1b803
Showing 3 changed files with 86 additions and 0 deletions.
56 changes: 56 additions & 0 deletions tests/test_jax_bindings.py
@@ -0,0 +1,56 @@
import torch
import unittest
import torchsparsegradutils as tsgu
import torchsparsegradutils.jax as tsgujax

if not tsgujax.have_jax:
    raise unittest.SkipTest("Importing optional jax-related module failed to find jax -> skipping jax-related tests.")

import numpy as np

import jax
import jax.numpy as jnp

class J2TIOTest(unittest.TestCase):
    """IO conversion tests between torch and jax"""

    def setUp(self) -> None:
        # The device can be specialised by a daughter class
        if not hasattr(self, "device"):
            self.device = torch.device("cpu")
            self.device_j = jax.devices("cpu")[0]
        self.x_shape = (4,4)
        self.x_t = torch.randn(self.x_shape, dtype=torch.float64, device=self.device)
        self.x_j = jax.device_put(np.random.randn(self.x_shape[0],self.x_shape[1]), device=self.device_j)
        #print(self.x_j.device_buffer.device())

    def test_t2j(self):
        x_j = tsgujax.t2j(self.x_t)
        self.assertTrue( x_j.shape == self.x_t.shape )
        self.assertTrue( np.isclose( np.asarray(x_j), self.x_t.numpy() ).all() )
        x_t2j2t = tsgujax.j2t(x_j)
        self.assertTrue( x_t2j2t.shape == self.x_t.shape )
        self.assertTrue( np.isclose( x_t2j2t.numpy(), self.x_t.numpy() ).all() )

    def test_j2t(self):
        x_t = tsgujax.j2t(self.x_j)
        self.assertTrue( x_t.shape == self.x_j.shape )
        self.assertTrue( np.isclose( np.asarray(self.x_j), x_t.numpy() ).all() )
        x_j2t2j = tsgujax.t2j(x_t)
        self.assertTrue( x_j2t2j.shape == self.x_j.shape )
        self.assertTrue( np.isclose( np.asarray(x_j2t2j), np.asarray(self.x_j) ).all() )


class J2TIOTestCUDA(J2TIOTest):
    """Override superclass setUp to run on GPU"""

    def setUp(self) -> None:
        if not torch.cuda.is_available():
            self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available")
        self.device = torch.device("cuda")
        self.device_j = jax.devices("gpu")[0]
        super().setUp()


if __name__ == "__main__":
    unittest.main()
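Usage note (not part of the commit): with torch, numpy and jax installed, the new tests can be run directly through the unittest entry point above, e.g. python tests/test_jax_bindings.py from the repository root; when the module is collected by the unittest loader and jax is missing, the module-level SkipTest is recorded as a skip for the whole file rather than as an error.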
12 changes: 12 additions & 0 deletions torchsparsegradutils/jax/__init__.py
@@ -0,0 +1,12 @@
import importlib
jax_spec = importlib.util.find_spec("jax")
if jax_spec is None:
    have_jax = False
    import warnings
    warnings.warn("\n\nAttempting to import an optional module in torchsparsegradutils that depends on jax but jax couldn't be imported.\n")
else:
    have_jax = True
    from .jax_bindings import j2t
    from .jax_bindings import t2j
    __all__ = ["j2t", "t2j"]
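
For context, a minimal sketch of how downstream code might consume the have_jax flag defined above; the snippet is illustrative only and not part of the commit:

import torchsparsegradutils.jax as tsgujax

if tsgujax.have_jax:
    # jax was found, so the DLPack-backed converters were re-exported
    import torch
    x_j = tsgujax.t2j(torch.randn(3, 3))
else:
    # jax is missing: importing the subpackage only emitted a warning
    print("jax not available; taking a torch-only code path")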

18 changes: 18 additions & 0 deletions torchsparsegradutils/jax/jax_bindings.py
@@ -0,0 +1,18 @@
import jax
from jax import dlpack as jax_dlpack

import torch
from torch.utils import dlpack as torch_dlpack

def j2t(x_jax):
    # Convert a jax array to a torch tensor
    # See https://github.com/lucidrains/jax2torch/blob/main/jax2torch/jax2torch.py
    x_torch = torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(x_jax))
    return x_torch

def t2j(x_torch):
    # Convert a torch tensor to a jax array
    # See https://github.com/lucidrains/jax2torch/blob/main/jax2torch/jax2torch.py
    x_torch = x_torch.contiguous()  # https://github.com/google/jax/issues/8082
    x_jax = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(x_torch))
    return x_jax
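
As a quick sanity check of the two converters, a round-trip sketch (assuming both torch and jax are installed; values should survive both DLPack hops unchanged):

import numpy as np
import torch
from torchsparsegradutils.jax.jax_bindings import j2t, t2j

x_t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
x_j = t2j(x_t)     # torch -> jax via DLPack
x_back = j2t(x_j)  # jax -> torch via DLPack
assert np.allclose(np.asarray(x_j), x_t.numpy())
assert torch.equal(x_back, x_t)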
