-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
skeleton for optional dependencies impleted for jax - see #9
- Loading branch information
Showing
3 changed files
with
86 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import torch | ||
import unittest | ||
import torchsparsegradutils as tsgu | ||
import torchsparsegradutils.jax as tsgujax | ||
|
||
if not tsgujax.have_jax: | ||
raise unittest.SkipTest("Importing optional jax-related module failed to find jax -> skipping jax-related tests.") | ||
|
||
import numpy as np | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
|
||
class J2TIOTest(unittest.TestCase): | ||
"""IO conversion tests between torch and jax""" | ||
|
||
def setUp(self) -> None: | ||
# The device can be specialised by a daughter class | ||
if not hasattr(self, "device"): | ||
self.device = torch.device("cpu") | ||
self.device_j = jax.devices("cpu")[0] | ||
self.x_shape = (4,4) | ||
self.x_t = torch.randn(self.x_shape, dtype=torch.float64, device=self.device) | ||
self.x_j = jax.device_put(np.random.randn(self.x_shape[0],self.x_shape[1]), device=self.device_j) | ||
#print(self.x_j.device_buffer.device()) | ||
|
||
def test_t2j(self): | ||
x_j = tsgujax.t2j(self.x_t) | ||
self.assertTrue( x_j.shape == self.x_t.shape ) | ||
self.assertTrue( np.isclose( np.asarray(x_j), self.x_t.numpy() ).all() ) | ||
x_t2j2t = tsgujax.j2t(x_j) | ||
self.assertTrue( x_t2j2t.shape == self.x_t.shape ) | ||
self.assertTrue( np.isclose( x_t2j2t.numpy(), self.x_t.numpy() ).all() ) | ||
|
||
def test_j2t(self): | ||
x_t = tsgujax.j2t(self.x_j) | ||
self.assertTrue( x_t.shape == self.x_j.shape ) | ||
self.assertTrue( np.isclose( np.asarray(self.x_j), x_t.numpy() ).all() ) | ||
x_j2t2j = tsgujax.t2j(x_t) | ||
self.assertTrue( x_j2t2j.shape == self.x_j.shape ) | ||
self.assertTrue( np.isclose( np.asarray(x_j2t2j), np.asarray(self.x_j) ).all() ) | ||
|
||
|
||
class J2TIOTestCUDA(J2TIOTest): | ||
"""Override superclass setUp to run on GPU""" | ||
|
||
def setUp(self) -> None: | ||
if not torch.cuda.is_available(): | ||
self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available") | ||
self.device = torch.device("cuda") | ||
self.device_j = jax.devices("gpu")[0] | ||
super().setUp() | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import importlib | ||
jax_spec = importlib.util.find_spec("jax") | ||
if jax_spec is None: | ||
have_jax = False | ||
import warnings | ||
warnings.warn("\n\nAttempting to import an optional module in torchsparsegradutils that depends on jax but jax couldn't be imported.\n") | ||
else: | ||
have_jax = True | ||
from .jax_bindings import j2t | ||
from .jax_bindings import t2j | ||
__all__ = ["j2t", "t2j"] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import jax | ||
from jax import dlpack as jax_dlpack | ||
|
||
import torch | ||
from torch.utils import dlpack as torch_dlpack | ||
|
||
def j2t(x_jax): | ||
# Convert a jax array to a torch tensor | ||
# See https://github.com/lucidrains/jax2torch/blob/main/jax2torch/jax2torch.py | ||
x_torch = torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(x_jax)) | ||
return x_torch | ||
|
||
def t2j(x_torch): | ||
# Convert a torch tensor to a jax array | ||
# See https://github.com/lucidrains/jax2torch/blob/main/jax2torch/jax2torch.py | ||
x_torch = x_torch.contiguous() # https://github.com/google/jax/issues/8082 | ||
x_jax = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(x_torch)) | ||
return x_jax |