Skip to content

Commit

Permalink
Merge pull request #20 from parrt/jax
Browse files Browse the repository at this point in the history
Add JAX support
  • Loading branch information
parrt authored Dec 3, 2020
2 parents c1da933 + e5ed697 commit 8694be5
Show file tree
Hide file tree
Showing 7 changed files with 1,637 additions and 433 deletions.
21 changes: 17 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Tensor Sensor

<img src="https://explained.ai/tensor-sensor/images/teaser.png" width="50%" align="right">One of the biggest challenges when writing code to implement deep learning networks, particularly for us newbies, is getting all of the tensor (matrix and vector) dimensions to line up properly. It's really easy to lose track of tensor dimensionality in complicated expressions involving multiple tensors and tensor operations. Even when just feeding data into predefined [Tensorflow](https://www.tensorflow.org/) network layers, we still need to get the dimensions right. When you ask for improper computations, you're going to run into some less than helpful exception messages. To help myself and other programmers debug tensor code, I built this library. TensorSensor clarifies exceptions by augmenting messages and visualizing Python code to indicate the shape of tensor variables (see figure to the right for a teaser). It works with [Tensorflow](https://www.tensorflow.org/), [PyTorch](https://pytorch.org/), and [Numpy](https://numpy.org/), as well as higher-level libraries like [Keras](https://keras.io/) and [fastai](https://www.fast.ai/).
<img src="https://explained.ai/tensor-sensor/images/teaser.png" width="50%" align="right">One of the biggest challenges when writing code to implement deep learning networks, particularly for us newbies, is getting all of the tensor (matrix and vector) dimensions to line up properly. It's really easy to lose track of tensor dimensionality in complicated expressions involving multiple tensors and tensor operations. Even when just feeding data into predefined [Tensorflow](https://www.tensorflow.org/) network layers, we still need to get the dimensions right. When you ask for improper computations, you're going to run into some less than helpful exception messages.

To help myself and other programmers debug tensor code, I built this library. TensorSensor clarifies exceptions by augmenting messages and visualizing Python code to indicate the shape of tensor variables (see figure to the right for a teaser). It works with [Tensorflow](https://www.tensorflow.org/), [PyTorch](https://pytorch.org/), [JAX](https://github.com/google/jax), and [Numpy](https://numpy.org/), as well as higher-level libraries like [Keras](https://keras.io/) and [fastai](https://www.fast.ai/).

Please read the complete description in article [Clarifying exceptions and visualizing tensor operations in deep learning code](https://explained.ai/tensor-sensor/index.html).

Expand Down Expand Up @@ -35,14 +37,24 @@ TensorSensor augments the message with more information about which operator cau
Cause: @ on tensor operand W w/shape [764, 100] and operand X.T w/shape [764, 200]
```

You can also get the full computation graph for an expression that includes all of these sub result shapes.

```python
tsensor.astviz("b = W@b + (h+3).dot(h) + torch.abs(torch.tensor(34))", sys._getframe())
```

yields the following abstract syntax tree with shapes:

<img src="images/ast.svg" width="400">

## Install

```
pip install tensor-sensor # This will only install the library for you
pip install tensor-sensor[torch] # install pytorch related dependency
pip install tensor-sensor[tensorflow] # install tensorflow related dependency
pip install tensor-sensor[all] # install both tensorflow and pytorch
pip install tensor-sensor[jax] # install jax, jaxlib
pip install tensor-sensor[all] # install tensorflow, pytorch, jax
```

which gives you module `tsensor`. I developed and tested with the following versions
Expand All @@ -56,6 +68,9 @@ numpy 1.18.5
numpydoc 1.1.0
$ pip list | grep -i torch
torch 1.6.0
$ pip list | grep -i jax
jax 0.2.6
jaxlib 0.1.57
```

### Graphviz for tsensor.astviz()
Expand Down Expand Up @@ -115,5 +130,3 @@ $ pip install .
### TODO

* can i call pyviz in debugger?
* try on real examples
* `dict(W=[3,0,1,2], b=[1,0])` that would indicate (300, 30, 60, 3) would best be displayed as (30,60,3, 300) and b would be first dimension last and last dimension first
Loading

0 comments on commit 8694be5

Please sign in to comment.