diff --git a/testing/examples.ipynb b/testing/examples.ipynb new file mode 100644 index 0000000..e7b0d22 --- /dev/null +++ b/testing/examples.ipynb @@ -0,0 +1,1057 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Clarifying the cause of tensor exceptions" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "import tsensor\n", + "import graphviz\n", + "import torch\n", + "import sys\n", + "\n", + "W = torch.tensor([[1, 2], [3, 4]])\n", + "b = torch.tensor([9, 10]).reshape(2, 1)\n", + "x = torch.tensor([4, 5]).reshape(2, 1)\n", + "h = torch.tensor([1,2])" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "PyTorch says: 1D tensors expected, got 2D, 2D tensors at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorEvenMoreMath.cpp:83\n", + "\n", + "tsensor adds: Cause: torch.dot(b,b) tensor arg b w/shape [2, 1], arg b w/shape [2, 1]\n" + ] + } + ], + "source": [ + "try: # try is used just to catch the exception and extract the messages\n", + " with tsensor.clarify():\n", + " W @ torch.dot(b,b)+ torch.eye(2,2)@x + z\n", + "except BaseException as e:\n", + " msgs = e.args[0].split(\"\\n\")\n", + " sys.stderr.write(\"PyTorch says: \"+msgs[0]+'\\n\\n')\n", + " sys.stderr.write(\"tsensor adds: \"+msgs[1]+'\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "PyTorch says: 1D tensors expected, got 2D, 1D tensors at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorEvenMoreMath.cpp:83\n", + "\n", + "tsensor adds: Cause: W.dot(h) tensor arg h w/shape [2]\n" + ] + } + ], + "source": [ + "try:\n", + " with tsensor.clarify():\n", + " W.dot(h) + x\n", + "except BaseException as e:\n", + " msgs = e.args[0].split(\"\\n\")\n", + " sys.stderr.write(\"PyTorch says: \"+msgs[0]+'\\n\\n')\n", + " sys.stderr.write(\"tsensor adds: \"+msgs[1]+'\\n')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Explaining matrix algebra statements visually" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "G\n", + "\n", + "\n", + "\n", + "leaf140358874896368\n", + "\n", + "    \n", + "2x1\n", + "    \n", + "\n", + "\n", + "    \n", + "a\n", + "    \n", + "\n", + "\n", + "\n", + "leaf140358874898096\n", + "\n", + "=\n", + "\n", + "\n", + "\n", + "\n", + "leaf140358874895840\n", + "\n", + "torch\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532779760\n", + "\n", + ".\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532782400\n", + "\n", + "relu\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532780384\n", + "\n", + "(\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532782352\n", + "\n", + "    \n", + "2x1\n", + "    \n", + "\n", + "\n", + "    \n", + "x\n", + "    \n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532778560\n", + "\n", + ")\n", + "\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "G\n", + "\n", + "\n", + "\n", + "leaf140357532780480\n", + "\n", + "    \n", + "2x1\n", + "    \n", + "\n", + "\n", + "    \n", + "b\n", + "    \n", + "\n", + "\n", + "\n", + "leaf140357532778800\n", + "\n", + "=\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532781488\n", + "\n", + "    \n", + "2x2\n", + "    \n", + "\n", + "\n", + "    \n", + "W\n", + "    \n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532779712\n", + "\n", + "@\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532780432\n", + "\n", + "    \n", + "2x1\n", + "    \n", + "\n", + "\n", + "    \n", + "b\n", + "    \n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532780864\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532781584\n", + "\n", + "    \n", + "2x1\n", + "    \n", + "\n", + "\n", + "    \n", + "x\n", + "    \n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532781632\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532781680\n", + "\n", + "3\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532778656\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532780048\n", + "\n", + "    \n", + "2\n", + "    \n", + "\n", + "\n", + "    \n", + "h\n", + "    \n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532779280\n", + "\n", + ".\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532782064\n", + "\n", + "dot\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532782304\n", + "\n", + "(\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532779088\n", + "\n", + "    \n", + "2\n", + "    \n", + "\n", + "\n", + "    \n", + "h\n", + "    \n", + "\n", + "\n", + "\n", + "\n", + "leaf140357532781776\n", + "\n", + ")\n", + "\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "with tsensor.explain():\n", + " a = torch.relu(x)\n", + " b = W @ b + x * 3 + h.dot(h)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Saving explanations to files" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "with tsensor.explain(savefig=\"/tmp/foo\"): # save foo-1.svg and foo-2.svg in /tmp\n", + " a = torch.relu(x)\n", + " b = W @ b + x * 3 + h.dot(h)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-rw-r--r--@ 1 parrt wheel 5004 Sep 9 17:53 /tmp/foo-1.svg\n", + "-rw-r--r--@ 1 parrt wheel 11068 Sep 9 17:53 /tmp/foo-2.svg\n" + ] + } + ], + "source": [ + "!ls -l /tmp/foo-?.svg" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "G\n", + "\n", + "\n", + "\n", + "leaf140357531149072\n", + "\n", + "    \n", + "2x1\n", + "    \n", + "\n", + "\n", + "    \n", + "b\n", + "    \n", + "\n", + "\n", + "\n", + "leaf140357531149168\n", + "\n", + "=\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357531149264\n", + "\n", + "    \n", + "2x2\n", + "    \n", + "\n", + "\n", + "    \n", + "W\n", + "    \n", + "\n", + "\n", + "\n", + "\n", + "leaf140357531149360\n", + "\n", + "@\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357531149456\n", + "\n", + "    \n", + "2x1\n", + "    \n", + "\n", + "\n", + "    \n", + "b\n", + "    \n", + "\n", + "\n", + "\n", + "\n", + "leaf140357531149552\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357531149648\n", + "\n", + "    \n", + "2x1\n", + "    \n", + "\n", + "\n", + "    \n", + "x\n", + "    \n", + "\n", + "\n", + "\n", + "\n", + "leaf140357531149744\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357531149840\n", + "\n", + "3\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357531149936\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357531150032\n", + "\n", + "    \n", + "2\n", + "    \n", + "\n", + "\n", + "    \n", + "h\n", + "    \n", + "\n", + "\n", + "\n", + "\n", + "leaf140357531150128\n", + "\n", + ".\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357531150224\n", + "\n", + "dot\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357531150320\n", + "\n", + "(\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357531150416\n", + "\n", + "    \n", + "2\n", + "    \n", + "\n", + "\n", + "    \n", + "h\n", + "    \n", + "\n", + "\n", + "\n", + "\n", + "leaf140357531150512\n", + "\n", + ")\n", + "\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import SVG\n", + "display(SVG(\"/tmp/foo-2.svg\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Lower-level API to show abstract parse trees and evaluate them" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [], + "source": [ + "frame = sys._getframe() # where are we executing right now" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Get ast and computation result" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[237],\n", + " [506]])" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "root, result = tsensor.eval(\"W @ b + x * 3 + h.dot(h)\", frame)\n", + "result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Show the ast" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "G\n", + "\n", + "\n", + "\n", + "leaf140357528650176\n", + "\n", + "W\n", + "\n", + "\n", + "\n", + "leaf140357528649888\n", + "\n", + "@\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357253350400\n", + "\n", + "b\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357253350592\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357253350976\n", + "\n", + "x\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357253349728\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357253349488\n", + "\n", + "3\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357253506336\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357253505616\n", + "\n", + "h\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357253507776\n", + "\n", + ".\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357253507824\n", + "\n", + "dot\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357253507920\n", + "\n", + "(\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357253508592\n", + "\n", + "h\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357253508016\n", + "\n", + ")\n", + "\n", + "\n", + "\n", + "\n", + "node140357528652000\n", + "\n", + "@\n", + "\n", + "\n", + "\n", + "node140357528652000->leaf140357528650176\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "node140357528652000->leaf140357253350400\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "node140357253508352\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "node140357253508352->leaf140357253350976\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "node140357253508352->leaf140357253349488\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "node140357530824176\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "node140357530824176->node140357528652000\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "node140357530824176->node140357253508352\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "node140357530823264\n", + "\n", + ".\n", + "\n", + "\n", + "\n", + "node140357530823264->leaf140357253505616\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "node140357530823264->leaf140357253507824\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "node140357530823408\n", + "\n", + "dot()\n", + "\n", + "\n", + "\n", + "node140357530823408->leaf140357253508592\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "node140357530823408->node140357530823264\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "node140357530823168\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "node140357530823168->node140357530824176\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "node140357530823168->node140357530823408\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tsensor.astviz(\"W @ b + x * 3 + h.dot(h)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Show specific code w/o need of a `with` statement:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "G\n", + "\n", + "\n", + "\n", + "leaf140357528652144\n", + "\n", + "    \n", + "2x2\n", + "    \n", + "\n", + "\n", + "    \n", + "W\n", + "    \n", + "\n", + "\n", + "\n", + "leaf140357528651616\n", + "\n", + "@\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357528652576\n", + "\n", + "    \n", + "2x1\n", + "    \n", + "\n", + "\n", + "    \n", + "b\n", + "    \n", + "\n", + "\n", + "\n", + "\n", + "leaf140357528651568\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357528652240\n", + "\n", + "    \n", + "2x1\n", + "    \n", + "\n", + "\n", + "    \n", + "x\n", + "    \n", + "\n", + "\n", + "\n", + "\n", + "leaf140357528651472\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357528653440\n", + "\n", + "3\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357528651904\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357528649888\n", + "\n", + "    \n", + "2\n", + "    \n", + "\n", + "\n", + "    \n", + "h\n", + "    \n", + "\n", + "\n", + "\n", + "\n", + "leaf140357528653152\n", + "\n", + ".\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357253351072\n", + "\n", + "dot\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357253349440\n", + "\n", + "(\n", + "\n", + "\n", + "\n", + "\n", + "leaf140357253349728\n", + "\n", + "    \n", + "2\n", + "    \n", + "\n", + "\n", + "    \n", + "h\n", + "    \n", + "\n", + "\n", + "\n", + "\n", + "leaf140357530823456\n", + "\n", + ")\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tsensor.pyviz(\"W @ b + x * 3 + h.dot(h)\", frame)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}