Commit

ran black

tvercaut committed Dec 5, 2022
1 parent 3e41db2 commit c08e4ba

Showing 3 changed files with 105 additions and 98 deletions.
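The commit message "ran black" means the files below were mechanically reformatted with the black code formatter; no behavioral change is intended. As a minimal sketch (assuming black is installed; the usual route is the `black .` CLI), the same rewrite can be reproduced through black's Python API:

import black

# Unformatted input mimicking patterns this diff removes.
src = "hh = torch.outer( gg, gg )\n#G += hh\ns = ['%10.4f' % xi for xi in x]\n"
print(black.format_str(src, mode=black.Mode()))
# Prints the style this commit adds:
#   hh = torch.outer(gg, gg)
#   # G += hh
#   s = ["%10.4f" % xi for xi in x]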
168 changes: 85 additions & 83 deletions tests/test_lsmr.py
@@ -23,29 +23,30 @@

from torchsparsegradutils.utils import lsmr

-#from numpy import array, arange, eye, zeros, ones, sqrt, transpose, hstack
-#from numpy.linalg import norm
-#from numpy.testing import assert_allclose
-#import pytest
-#from scipy.sparse import coo_matrix
-#from scipy.sparse.linalg._interface import aslinearoperator
-#from scipy.sparse.linalg import lsmr
-#from .test_lsqr import G, b
-
-def _gettestproblem(dtype,device):
+# from numpy import array, arange, eye, zeros, ones, sqrt, transpose, hstack
+# from numpy.linalg import norm
+# from numpy.testing import assert_allclose
+# import pytest
+# from scipy.sparse import coo_matrix
+# from scipy.sparse.linalg._interface import aslinearoperator
+# from scipy.sparse.linalg import lsmr
+# from .test_lsqr import G, b
+
+
+def _gettestproblem(dtype, device):
    # Set up a test problem
    n = 35
-    G = torch.eye(n,dtype=dtype,device=device)
+    G = torch.eye(n, dtype=dtype, device=device)

    for jj in range(5):
        gg = torch.randn(n, dtype=dtype, device=device)
-        hh = torch.outer( gg, gg )
+        hh = torch.outer(gg, gg)
        G += hh
-        #G += (hh + hh.reshape(-1,1)) * 0.5
-        #G += torch.randn(n, dtype=dtype, device=device) * torch.randn(n, dtype=dtype, device=device)
+        # G += (hh + hh.reshape(-1,1)) * 0.5
+        # G += torch.randn(n, dtype=dtype, device=device) * torch.randn(n, dtype=dtype, device=device)

    b = torch.randn(n, dtype=dtype, device=device)

    return G, b


@@ -57,16 +58,16 @@ def setUp(self):
        self.n = 10
        self.m = 10
        self.dtype = torch.float64
-        #self.cdtype = torch.cdouble
+        # self.cdtype = torch.cdouble

    def assertCompatibleSystem(self, A, xtrue):
        Afun = A.matmul
        b = Afun(xtrue)
        x = lsmr(A, b)[0]
-        #print("A",A)
-        #print("b",b)
-        #print("x",x)
-        #print("xtrue",xtrue)
+        # print("A",A)
+        # print("b",b)
+        # print("x",x)
+        # print("xtrue",xtrue)
        self.assertTrue(torch.allclose(x, xtrue, atol=1e-3, rtol=1e-4))

    def testIdentityACase1(self):
@@ -76,44 +77,44 @@ def testIdentityACase1(self):

    def testIdentityACase2(self):
        A = torch.eye(self.n, dtype=self.dtype, device=self.device)
-        xtrue = torch.ones((self.n,1), dtype=self.dtype, device=self.device)
+        xtrue = torch.ones((self.n, 1), dtype=self.dtype, device=self.device)
        self.assertCompatibleSystem(A, xtrue)

    def testIdentityACase3(self):
        A = torch.eye(self.n, dtype=self.dtype, device=self.device)
-        xtrue = torch.t(torch.arange(self.n,0,-1, dtype=self.dtype, device=self.device))
+        xtrue = torch.t(torch.arange(self.n, 0, -1, dtype=self.dtype, device=self.device))
        self.assertCompatibleSystem(A, xtrue)

    def testBidiagonalA(self):
-        A = lowerBidiagonalMatrix(20,self.n)
-        xtrue = torch.t(torch.arange(self.n,0,-1, dtype=self.dtype, device=self.device))
-        self.assertCompatibleSystem(A,xtrue)
+        A = lowerBidiagonalMatrix(20, self.n)
+        xtrue = torch.t(torch.arange(self.n, 0, -1, dtype=self.dtype, device=self.device))
+        self.assertCompatibleSystem(A, xtrue)

    def testScalarB(self):
        A = torch.tensor([[1.0, 2.0]], dtype=self.dtype, device=self.device)
        b = torch.tensor([3.0], dtype=self.dtype, device=self.device)
        x = lsmr(A, b)[0]
        self.assertTrue(torch.allclose(A.matmul(x), b, atol=1e-3, rtol=1e-4))

-    #def testComplexX(self):
+    # def testComplexX(self):
    # A = torch.eye(self.n, dtype=self.cdtype, device=self.device)
    # xtrue = torch.t(torch.arange(self.n, 0, -1, dtype=self.dtype, device=self.device) * (1 + 1j))
    # self.assertCompatibleSystem(A, xtrue)

-    #def testComplexX0(self):
+    # def testComplexX0(self):
    # A = 4 * torch.eye(self.n, dtype=self.dtype, device=self.device) + torch.ones((self.n, self.n), dtype=self.dtype, device=self.device)
    # xtrue = torch.t(torch.arange(self.n, 0, -1, dtype=self.dtype, device=self.device))
    # b = A.matmul(xtrue)
    # x0 = torch.zeros(self.n, dtype=complex, device=self.device)
    # x = lsmr(A, b, x0=x0)[0]
    # self.assertTrue(torch.allclose(x, xtrue, atol=1e-3, rtol=1e-4))

-    #def testComplexA(self):
+    # def testComplexA(self):
    # A = 4 * torch.eye(self.n, dtype=self.dtype, device=self.device) + 1j * torch.ones((self.n, self.n), dtype=self.dtype, device=self.device)
    # xtrue = torch.t(torch.arange(self.n, 0, -1, dtype=self.dtype, device=self.device).astype(complex))
    # self.assertCompatibleSystem(A, xtrue)

-    #def testComplexB(self):
+    # def testComplexB(self):
    # A = 4 * torch.eye(self.n, dtype=self.dtype, device=self.device) + torch.ones((self.n, self.n), dtype=self.dtype, device=self.device)
    # xtrue = torch.t(torch.arange(self.n, 0, -1, dtype=self.dtype, device=self.device) * (1 + 1j))
    # b = A.matmul(xtrue)
@@ -127,11 +128,11 @@ def testColumnB(self):
        self.assertTrue(torch.allclose(A.matmul(x), b, atol=1e-3, rtol=1e-4))

    def testInitialization(self):
-        G, b = _gettestproblem(self.dtype,self.device)
+        G, b = _gettestproblem(self.dtype, self.device)
        # Test that the default setting is not modified
-        #x_ref, _, itn_ref, normr_ref, *_ = lsmr(G, b)
+        # x_ref, _, itn_ref, normr_ref, *_ = lsmr(G, b)
        x_ref = lsmr(G, b)[0]
-        self.assertTrue(torch.allclose(G@x_ref, b, atol=1e-3, rtol=1e-4))
+        self.assertTrue(torch.allclose(G @ x_ref, b, atol=1e-3, rtol=1e-4))

        # Test passing zeros yields similiar result
        x0 = torch.zeros(b.shape, dtype=self.dtype, device=self.device)
@@ -141,9 +142,9 @@ def testInitialization(self):
        # Test warm-start with single iteration
        x0 = lsmr(G, b, maxiter=1)[0]

-        #x, _, itn, normr, *_ = lsmr(G, b, x0=x0)
+        # x, _, itn, normr, *_ = lsmr(G, b, x0=x0)
        x = lsmr(G, b, x0=x0)[0]
-        self.assertTrue(torch.allclose(G@x, b, atol=1e-3, rtol=1e-4))
+        self.assertTrue(torch.allclose(G @ x, b, atol=1e-3, rtol=1e-4))

        # NOTE(gh-12139): This doesn't always converge to the same value as
        # ref because error estimates will be slightly different when calculated
@@ -153,14 +154,14 @@ def testInitialization(self):
        # itn == itn_ref means that lsmr(x0) took an extra iteration see above.
        # -1 is technically possible but is rare (1 in 100000) so it's more
        # likely to be an error elsewhere.
-        #assert itn - itn_ref in (0, 1)
+        # assert itn - itn_ref in (0, 1)

        # If an extra iteration is performed normr may be 0, while normr_ref
        # may be much larger.
-        #assert normr < normr_ref * (1 + 1e-6)
+        # assert normr < normr_ref * (1 + 1e-6)

    def testVerbose(self):
-        lsmrtest(20,10,0,self.dtype,self.device)
+        lsmrtest(20, 10, 0, self.dtype, self.device)


class TestLSMRCUDA(TestLSMR):
@@ -171,7 +172,7 @@ def setUp(self) -> None:
            self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available")
        self.device = torch.device("cuda")
        super().setUp()
-
+

class TestLSMRReturns(unittest.TestCase):
    def setUp(self):
@@ -191,21 +192,28 @@ def setUp(self):
        self.returnValuesX0 = lsmr(self.A, self.b, x0=self.x0)

    def test_unchanged_x0(self):
-        #x, istop, itn, normr, normar, normA, condA, normx = self.returnValuesX0
+        # x, istop, itn, normr, normar, normA, condA, normx = self.returnValuesX0
        x = self.returnValuesX0[0]
        self.assertTrue(torch.allclose(self.x00, self.x0, atol=1e-3, rtol=1e-4))

    def testNormr(self):
-        #x, istop, itn, normr, normar, normA, condA, normx = self.returnValues
+        # x, istop, itn, normr, normar, normA, condA, normx = self.returnValues
        x = self.returnValues[0]
        self.assertTrue(torch.allclose(self.Afun(x), self.b, atol=1e-3, rtol=1e-4))

    def testNormar(self):
-        #x, istop, itn, normr, normar, normA, condA, normx = self.returnValues
+        # x, istop, itn, normr, normar, normA, condA, normx = self.returnValues
        x = self.returnValuesX0[0]
-        self.assertTrue(torch.allclose(self.Arfun(self.b - self.Afun(x)), torch.zeros(self.n, dtype=self.dtype, device=self.device), atol=1e-3, rtol=1e-4))
+        self.assertTrue(
+            torch.allclose(
+                self.Arfun(self.b - self.Afun(x)),
+                torch.zeros(self.n, dtype=self.dtype, device=self.device),
+                atol=1e-3,
+                rtol=1e-4,
+            )
+        )

-    #def testNormx(self):
+    # def testNormx(self):
    # x, istop, itn, normr, normar, normA, condA, normx = self.returnValues
    # assert norm(x) == pytest.approx(normx)

@@ -218,7 +226,7 @@ def setUp(self) -> None:
            self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available")
        self.device = torch.device("cuda")
        super().setUp()
-
+

def lowerBidiagonalMatrix(m, n):
    # This is a simple example for testing LSMR.
@@ -233,65 +241,59 @@ def lowerBidiagonalMatrix(m, n):
    #
    # 04 Jun 2010: First version for distribution with lsmr.py
    if m <= n:
-        row = torch.hstack((torch.arange(m, dtype=int),
-                            torch.arange(1, m, dtype=int)))
-        col = torch.hstack((torch.arange(m, dtype=int),
-                            torch.arange(m-1, dtype=int)))
-        idx = torch.vstack((row,col))
-        data = torch.hstack((torch.arange(1, m+1, dtype=float),
-                             torch.arange(1,m, dtype=float)))
-        return torch.sparse_coo_tensor(idx, data, size=(m,n))
+        row = torch.hstack((torch.arange(m, dtype=int), torch.arange(1, m, dtype=int)))
+        col = torch.hstack((torch.arange(m, dtype=int), torch.arange(m - 1, dtype=int)))
+        idx = torch.vstack((row, col))
+        data = torch.hstack((torch.arange(1, m + 1, dtype=float), torch.arange(1, m, dtype=float)))
+        return torch.sparse_coo_tensor(idx, data, size=(m, n))
    else:
-        row = torch.hstack((torch.arange(n, dtype=int),
-                            torch.arange(1, n+1, dtype=int)))
-        col = torch.hstack((torch.arange(n, dtype=int),
-                            torch.arange(n, dtype=int)))
-        idx = torch.vstack((row,col))
-        data = torch.hstack((torch.arange(1, n+1, dtype=float),
-                             torch.arange(1,n+1, dtype=float)))
-        return torch.sparse_coo_tensor(idx, data, size=(m,n))
+        row = torch.hstack((torch.arange(n, dtype=int), torch.arange(1, n + 1, dtype=int)))
+        col = torch.hstack((torch.arange(n, dtype=int), torch.arange(n, dtype=int)))
+        idx = torch.vstack((row, col))
+        data = torch.hstack((torch.arange(1, n + 1, dtype=float), torch.arange(1, n + 1, dtype=float)))
+        return torch.sparse_coo_tensor(idx, data, size=(m, n))


def lsmrtest(m, n, damp, dtype, device):
    """Verbose testing of lsmr"""

-    A = lowerBidiagonalMatrix(m,n)
-    xtrue = torch.arange(n,0,-1, dtype=dtype, device=device)
+    A = lowerBidiagonalMatrix(m, n)
+    xtrue = torch.arange(n, 0, -1, dtype=dtype, device=device)
    Afun = A.matmul

    b = Afun(xtrue)

    atol = 1.0e-7
    btol = 1.0e-7
-    conlim = 1.0e+10
-    itnlim = 10*n
+    conlim = 1.0e10
+    itnlim = 10 * n
    show = 1

-    #x, istop, itn, normr, normar, norma, conda, normx \
+    # x, istop, itn, normr, normar, norma, conda, normx \
    #    = lsmr(A, b, damp, atol, btol, conlim, itnlim, show)
    x = lsmr(A, b, damp=damp, atol=atol, btol=btol, conlim=conlim, maxiter=itnlim)[0]

-    j1 = min(n,5)
-    j2 = max(n-4,1)
-    print(' ')
-    print('First elements of x:')
-    str = ['%10.4f' % (xi) for xi in x[0:j1]]
-    print(''.join(str))
-    print(' ')
-    print('Last elements of x:')
-    str = ['%10.4f' % (xi) for xi in x[j2-1:]]
-    print(''.join(str))
+    j1 = min(n, 5)
+    j2 = max(n - 4, 1)
+    print(" ")
+    print("First elements of x:")
+    str = ["%10.4f" % (xi) for xi in x[0:j1]]
+    print("".join(str))
+    print(" ")
+    print("Last elements of x:")
+    str = ["%10.4f" % (xi) for xi in x[j2 - 1 :]]
+    print("".join(str))

    r = b - Afun(x)
-    r2 = torch.sqrt(torch.norm(r)**2 + (damp*torch.norm(x))**2)
-    print(' ')
-    #str = 'normr (est.) %17.10e' % (normr)
-    str2 = 'normr (true) %17.10e' % (r2)
-    #print(str)
+    r2 = torch.sqrt(torch.norm(r) ** 2 + (damp * torch.norm(x)) ** 2)
+    print(" ")
+    # str = 'normr (est.) %17.10e' % (normr)
+    str2 = "normr (true) %17.10e" % (r2)
+    # print(str)
    print(str2)
-    print(' ')
+    print(" ")


if __name__ == "__main__":
-    #lsmrtest(20,10,0)
+    # lsmrtest(20,10,0)
    unittest.main()
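Aside on the helpers exercised above: lowerBidiagonalMatrix(m, n) puts 1..min(m, n) on the main diagonal and a second run of values on the first subdiagonal, and the tests always consume only element [0] of lsmr's return tuple. A minimal sketch of both patterns (a dense mirror of the m <= n branch; the keyword names x0 and maxiter appear verbatim in the diff above):

import torch
from torchsparsegradutils.utils import lsmr

def lower_bidiagonal_dense(m: int, n: int) -> torch.Tensor:
    # Dense equivalent of the m <= n branch of lowerBidiagonalMatrix:
    # 1..m on the diagonal, 1..m-1 just below it.
    A = torch.zeros(m, n, dtype=torch.float64)
    i = torch.arange(m)
    A[i, i] = torch.arange(1, m + 1, dtype=torch.float64)
    A[i[1:], i[:-1]] = torch.arange(1, m, dtype=torch.float64)
    return A

A = lower_bidiagonal_dense(10, 10)
xtrue = torch.arange(10, 0, -1, dtype=torch.float64)
b = A @ xtrue

x = lsmr(A, b)[0]  # first element of the returned tuple is the solution
assert torch.allclose(x, xtrue, atol=1e-3, rtol=1e-4)

# Warm start, as in testInitialization: a 1-iteration solve seeds x0.
x0 = lsmr(A, b, maxiter=1)[0]
x = lsmr(A, b, x0=x0)[0]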
2 changes: 1 addition & 1 deletion torchsparsegradutils/utils/bicgstab.py
@@ -90,7 +90,7 @@ def bicgstab(
        x = torch.zeros(n, dtype=res_dtype, device=res_device)
    else:
        x = initial_guess.clone()
-
+
    # matvec_max = kwargs.get('matvec_max', 2*n)
    matvec_max = 2 * n if settings.matvec_max is None else settings.matvec_max

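The bicgstab hunk above keeps a settings-object default in place of the old kwargs lookup: a matvec_max of None means "derive the budget from the problem size". A minimal sketch of that pattern with hypothetical names (Settings and matvec_budget are illustrative, not the library's API):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Settings:
    matvec_max: Optional[int] = None  # None -> default to 2 * n

def matvec_budget(settings: Settings, n: int) -> int:
    # Mirrors the line above: fall back to 2 * n when unset.
    return 2 * n if settings.matvec_max is None else settings.matvec_max

assert matvec_budget(Settings(), 50) == 100
assert matvec_budget(Settings(matvec_max=10), 50) == 10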