-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding a simple notebook aiming at comparing different solvers
- Loading branch information
Showing
1 changed file
with
389 additions
and
0 deletions.
There are no files selected for viewing
389 changes: 389 additions & 0 deletions
389
torchsparsegradutils/tests/spd_forward_solve_notebook.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,389 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"provenance": [], | ||
"authorship_tag": "ABX9TyN0JFpiKDeXK9tTd/a7xn9o", | ||
"include_colab_link": true | ||
}, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
}, | ||
"accelerator": "GPU", | ||
"gpuClass": "standard" | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "view-in-github", | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"<a href=\"https://colab.research.google.com/github/cai4cai/torchsparsegradutils/blob/test-notebook/tests/spd_forward_solve_notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# A basic notebook to compare different solvers for a SPD matrix\n", | ||
"\n", | ||
"First import all necessary modules." | ||
], | ||
"metadata": { | ||
"id": "xvuQc5HyjX4t" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "3hHTa-xCjIQv", | ||
"outputId": "cf255d05-b0ab-4a79-9765-1ed23fbff784" | ||
}, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"Running PyTorch version: 1.13.0+cu116\n", | ||
"Default GPU is Tesla T4\n", | ||
"Running on cuda\n", | ||
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | ||
"Collecting git+https://github.com/cai4cai/torchsparsegradutils\n", | ||
" Cloning https://github.com/cai4cai/torchsparsegradutils to /tmp/pip-req-build-plwawj9f\n", | ||
" Running command git clone -q https://github.com/cai4cai/torchsparsegradutils /tmp/pip-req-build-plwawj9f\n", | ||
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", | ||
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", | ||
" Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n", | ||
"Requirement already satisfied: torch>=1.13 in /usr/local/lib/python3.8/dist-packages (from torchsparsegradutils==0.0.1) (1.13.0+cu116)\n", | ||
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch>=1.13->torchsparsegradutils==0.0.1) (4.4.0)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import torch\n", | ||
"print(f'Running PyTorch version: {torch.__version__}')\n", | ||
"\n", | ||
"torchdevice = torch.device('cpu')\n", | ||
"if torch.cuda.is_available():\n", | ||
" torchdevice = torch.device('cuda')\n", | ||
" print('Default GPU is ' + torch.cuda.get_device_name(torch.device('cuda')))\n", | ||
"print('Running on ' + str(torchdevice))\n", | ||
"\n", | ||
"import numpy as np\n", | ||
"import scipy.io\n", | ||
"import scipy.sparse.linalg\n", | ||
"\n", | ||
"import cupyx\n", | ||
"import cupyx.scipy.sparse as csp\n", | ||
"\n", | ||
"import jax\n", | ||
"from jax.config import config\n", | ||
"config.update(\"jax_enable_x64\", True)\n", | ||
"\n", | ||
"!pip install git+https://github.com/cai4cai/torchsparsegradutils\n", | ||
"import torchsparsegradutils as tsgu\n", | ||
"import torchsparsegradutils.utils\n", | ||
"import torchsparsegradutils.cupy as tsgucupy\n", | ||
"import torchsparsegradutils.jax as tsgujax\n", | ||
"\n", | ||
"import time\n", | ||
"import urllib\n", | ||
"import os.path\n", | ||
"import tarfile" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Now load an example SPD matrix and create a random RHS vector" | ||
], | ||
"metadata": { | ||
"id": "Vp-qCHEDkQbm" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"def load_mat_from_suitesparse_collection(dirname,matname):\n", | ||
" base_url = 'https://suitesparse-collection-website.herokuapp.com/MM/'\n", | ||
" url = base_url + dirname + '/' + matname + '.tar.gz'\n", | ||
" compressedlocalfile = matname + '.tar.gz'\n", | ||
" if not os.path.exists(compressedlocalfile):\n", | ||
" print(f'Downloading {url}')\n", | ||
" urllib.request.urlretrieve(url, filename=compressedlocalfile)\n", | ||
"\n", | ||
" localfile = './' + matname + '/' + matname + '.mtx'\n", | ||
" if not os.path.exists(localfile):\n", | ||
" print(f'untarring {compressedlocalfile}')\n", | ||
" srctarfile = tarfile.open(compressedlocalfile)\n", | ||
" srctarfile.extractall('./')\n", | ||
" srctarfile.close()\n", | ||
"\n", | ||
" A_np_coo = scipy.io.mmread(localfile)\n", | ||
" print(f'Loaded suitesparse matrix {dirname}/{matname}: type={type(A_np_coo)}, shape={A_np_coo.shape}')\n", | ||
" return A_np_coo\n", | ||
"\n", | ||
"A_np_coo = load_mat_from_suitesparse_collection('Rothberg','cfd2')\n", | ||
"A_np_csr = scipy.sparse.csr_matrix(A_np_coo)\n", | ||
"\n", | ||
"b_np = np.random.randn(A_np_coo.shape[1])\n", | ||
"print(f'Created random RHS with shape={b_np.shape}')" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "0CzeubR_kUVA", | ||
"outputId": "dc3ff1db-1425-4a82-cf54-0ad60f81963c" | ||
}, | ||
"execution_count": 2, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"Loaded suitesparse matrix Rothberg/cfd2: type=<class 'scipy.sparse.coo.coo_matrix'>, shape=(123440, 123440)\n", | ||
"Created random RHS with shape=(123440,)\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Define some helper functions to run the tests" | ||
], | ||
"metadata": { | ||
"id": "MUacEhsqoYxZ" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"def scipy_test(A, b, tested_solver, print_string):\n", | ||
" t = time.time()\n", | ||
" x = tested_solver(A, b)\n", | ||
" elapsed = time.time() - t\n", | ||
" resnorm = scipy.linalg.norm(A @ x - b)\n", | ||
" print(f'{print_string} took {elapsed:.2f} seconds - resnorm={resnorm:.2e}')\n", | ||
" return elapsed, resnorm" | ||
], | ||
"metadata": { | ||
"id": "1jYXGRDlnmYA" | ||
}, | ||
"execution_count": 3, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Run the tests with the scipy routines (it can take a while)" | ||
], | ||
"metadata": { | ||
"id": "Th8gF8MnoeGR" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"run_scipy = True\n", | ||
"if run_scipy:\n", | ||
" t, n = scipy_test(A_np_coo, b_np, scipy.sparse.linalg.spsolve, 'scipy.spsolve COO')\n", | ||
" t, n = scipy_test(A_np_csr, b_np, scipy.sparse.linalg.spsolve, 'scipy.spsolve CSR')\n", | ||
" t, n = scipy_test(A_np_coo, b_np, lambda A, b: scipy.sparse.linalg.cg(A,b)[0], 'scipy.cg COO')\n", | ||
" t, n = scipy_test(A_np_csr, b_np, lambda A, b: scipy.sparse.linalg.cg(A,b)[0], 'scipy.cg CSR')\n", | ||
" t, n = scipy_test(A_np_coo, b_np, lambda A, b: scipy.sparse.linalg.bicgstab(A,b)[0], 'scipy.bicgstab COO')\n", | ||
" t, n = scipy_test(A_np_csr, b_np, lambda A, b: scipy.sparse.linalg.bicgstab(A,b)[0], 'scipy.bicgstab CSR')\n", | ||
" t, n = scipy_test(A_np_coo, b_np, lambda A, b: scipy.sparse.linalg.minres(A,b)[0], 'scipy.minres COO')\n", | ||
" t, n = scipy_test(A_np_csr, b_np, lambda A, b: scipy.sparse.linalg.minres(A,b)[0], 'scipy.minres CSR')\n", | ||
"else:\n", | ||
" print('Skipping scipy tests')" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "4KaNyZpDogS4", | ||
"outputId": "34e0dc0b-392b-4485-ac55-1233dde26aaf" | ||
}, | ||
"execution_count": 4, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stderr", | ||
"text": [ | ||
"/usr/local/lib/python3.8/dist-packages/scipy/sparse/linalg/dsolve/linsolve.py:144: SparseEfficiencyWarning: spsolve requires A be CSC or CSR matrix format\n", | ||
" warn('spsolve requires A be CSC or CSR matrix format',\n" | ||
] | ||
}, | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"scipy.spsolve COO took 108.62 seconds - resnorm=2.15e-10\n", | ||
"scipy.spsolve CSR took 98.29 seconds - resnorm=2.49e-10\n", | ||
"scipy.cg COO took 72.09 seconds - resnorm=3.50e-03\n", | ||
"scipy.cg CSR took 55.59 seconds - resnorm=3.50e-03\n", | ||
"scipy.bicgstab COO took 112.65 seconds - resnorm=2.31e-03\n", | ||
"scipy.bicgstab CSR took 88.79 seconds - resnorm=2.31e-03\n", | ||
"scipy.minres COO took 0.54 seconds - resnorm=3.05e+01\n", | ||
"scipy.minres CSR took 0.38 seconds - resnorm=3.05e+01\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Create corresponding PyTorch tensors" | ||
], | ||
"metadata": { | ||
"id": "BqOzcOXYrMY0" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"A_tcpu_csr = tsgucupy.c2t_csr(A_np_csr)\n", | ||
"b_tcpu = torch.from_numpy(b_np)\n", | ||
"\n", | ||
"if torch.cuda.is_available():\n", | ||
" A_tgpu_csr = A_tcpu_csr.to(torchdevice)\n", | ||
" b_tgpu = b_tcpu.to(torchdevice)" | ||
], | ||
"metadata": { | ||
"id": "DHqwjtVOrRyH", | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"outputId": "307faf94-ee4a-4943-d483-ad084246ee0d" | ||
}, | ||
"execution_count": 5, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stderr", | ||
"text": [ | ||
"/usr/local/lib/python3.8/dist-packages/torchsparsegradutils/cupy/cupy_bindings.py:52: UserWarning: Sparse CSR tensor support is in beta state. If you miss a functionality in the sparse tensor support, please submit a feature request to https://github.com/pytorch/pytorch/issues. (Triggered internally at ../aten/src/ATen/SparseCsrTensorImpl.cpp:54.)\n", | ||
" x_torch = torch.sparse_csr_tensor(ind_ptr_t, idices_t, data_t, x_cupy.shape)\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Define some helper function to run the tests with pytorch" | ||
], | ||
"metadata": { | ||
"id": "zRZ7yXiZrjJ3" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"def torch_test(A, b, tested_solver, print_string):\n", | ||
" t = time.time()\n", | ||
" x = tested_solver(A, b)\n", | ||
" elapsed = time.time() - t\n", | ||
" resnorm = torch.norm(A @ x - b).cpu().numpy()\n", | ||
" print(f'{print_string} took {elapsed:.2f} seconds - resnorm={resnorm:.2e}')\n", | ||
" print(f'GPU memory allocated: {torch.cuda.memory_allocated(device=torchdevice)/10**9:.2f}Gb'\n", | ||
" f' - max allocated: {torch.cuda.max_memory_allocated(device=torchdevice)/10**9:.2f}Gb')\n", | ||
" #print(torch.cuda.memory_summary(abbreviated=True))\n", | ||
" return elapsed, resnorm" | ||
], | ||
"metadata": { | ||
"id": "jE0EFZiJrw71" | ||
}, | ||
"execution_count": 6, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Run the tests with our pytorch routines (this can take a while)" | ||
], | ||
"metadata": { | ||
"id": "-nSPnGeosFGr" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"print(f'GPU memory allocated: {torch.cuda.memory_allocated(device=torchdevice)/10**9:.2f}Gb'\n", | ||
" f' - max allocated: {torch.cuda.max_memory_allocated(device=torchdevice)/10**9:.2f}Gb')\n", | ||
"#print(torch.cuda.memory_summary(abbreviated=True))\n", | ||
"\n", | ||
"#t, n = torch_test(A_tcpu_csr, b_tcpu, tsgu.sparse_generic_solve, 'tsgu.sparse_generic_solve CPU CSR')\n", | ||
"if torch.cuda.is_available():\n", | ||
" t, n = torch_test(A_tgpu_csr, b_tgpu, tsgu.sparse_generic_solve, 'tsgu.sparse_generic_solve GPU CSR')\n", | ||
"\n", | ||
"#t, n = torch_test(A_tcpu_csr, b_tcpu, tsgu.sparse_generic_lstsq, 'tsgu.sparse_generic_lstsq CPU CSR')\n", | ||
"#if torch.cuda.is_available():\n", | ||
"# t, n = torch_test(A_tgpu_csr, b_tgpu, tsgu.sparse_generic_lstsq, 'tsgu.sparse_generic_lstsq GPU CSR')\n", | ||
"\n", | ||
"mysolver = lambda A, b: tsgu.sparse_generic_solve(A,b,solve=tsgu.utils.minres)\n", | ||
"if torch.cuda.is_available():\n", | ||
" t, n = torch_test(A_tgpu_csr, b_tgpu, mysolver, 'tsgu.sparse_generic_solve minres GPU CSR')\n", | ||
"\n", | ||
"mysolver = lambda A, b: tsgu.sparse_generic_solve(A,b,solve=tsgu.utils.linear_cg)\n", | ||
"if torch.cuda.is_available():\n", | ||
" t, n = torch_test(A_tgpu_csr, b_tgpu, mysolver, 'tsgu.sparse_generic_solve cg GPU CSR')\n", | ||
"\n", | ||
"mysolver = lambda A, b: tsgu.sparse_generic_solve(A,b,solve=tsgu.utils.bicgstab)\n", | ||
"if torch.cuda.is_available():\n", | ||
" t, n = torch_test(A_tgpu_csr, b_tgpu, mysolver, 'tsgu.sparse_generic_solve bicgstab GPU CSR')\n", | ||
"\n", | ||
"jaxsolver = None\n", | ||
"mysolver = lambda A, b: tsgujax.sparse_solve_j4t(A,b)\n", | ||
"if torch.cuda.is_available():\n", | ||
" t, n = torch_test(A_tgpu_csr, b_tgpu, mysolver, 'tsgu.sparse_solve_j4t cg GPU CSR')\n", | ||
"\n", | ||
"#cpsolver = lambda AA, BB: csp.linalg.cg(AA,BB)[0]\n", | ||
"#mysolver = lambda A, b: tsgucupy.sparse_solve_c4t(A,b,solve=cpsolver)\n", | ||
"#mysolver = lambda A, b: tsgucupy.sparse_solve_c4t(A,b)\n", | ||
"#t, n = torch_test(A_tcpu_csr, b_tcpu, mysolver, 'tsgu.sparse_solve_c4t cg CPU CSR')\n", | ||
"#if torch.cuda.is_available():\n", | ||
"# t, n = torch_test(A_tgpu_csr, b_tgpu, mysolver, 'tsgu.sparse_solve_c4t cg GPU CSR')" | ||
], | ||
"metadata": { | ||
"id": "3HH-nGDSsIXy", | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"outputId": "7a8dacff-efaa-416a-f6d5-9a87a2b84157" | ||
}, | ||
"execution_count": 7, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"GPU memory allocated: 0.04Gb - max allocated: 0.04Gb\n", | ||
"tsgu.sparse_generic_solve GPU CSR took 0.62 seconds - resnorm=1.93e+00\n", | ||
"GPU memory allocated: 0.04Gb - max allocated: 0.05Gb\n", | ||
"tsgu.sparse_generic_solve minres GPU CSR took 0.46 seconds - resnorm=1.93e+00\n", | ||
"GPU memory allocated: 0.04Gb - max allocated: 0.05Gb\n", | ||
"tsgu.sparse_generic_solve cg GPU CSR took 0.06 seconds - resnorm=2.43e+02\n", | ||
"GPU memory allocated: 0.04Gb - max allocated: 0.05Gb\n", | ||
"tsgu.sparse_generic_solve bicgstab GPU CSR took 9.35 seconds - resnorm=2.60e-04\n", | ||
"GPU memory allocated: 0.04Gb - max allocated: 0.05Gb\n", | ||
"tsgu.sparse_solve_j4t cg GPU CSR took 6.39 seconds - resnorm=2.74e-03\n", | ||
"GPU memory allocated: 0.04Gb - max allocated: 0.05Gb\n" | ||
] | ||
} | ||
] | ||
} | ||
] | ||
} |