Skip to content

Commit

Permalink
fixing device in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tvercaut committed Dec 5, 2022
1 parent c08e4ba commit bca6b2c
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions tests/test_lsmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def testIdentityACase3(self):
self.assertCompatibleSystem(A, xtrue)

def testBidiagonalA(self):
A = lowerBidiagonalMatrix(20, self.n)
A = lowerBidiagonalMatrix(20, self.n, self.dtype, self.device)
xtrue = torch.t(torch.arange(self.n, 0, -1, dtype=self.dtype, device=self.device))
self.assertCompatibleSystem(A, xtrue)

Expand Down Expand Up @@ -181,7 +181,7 @@ def setUp(self):
self.device = torch.device("cpu")
self.dtype = torch.float64
self.n = 10
self.A = lowerBidiagonalMatrix(20, self.n)
self.A = lowerBidiagonalMatrix(20, self.n, self.dtype, self.device)
self.xtrue = torch.t(torch.arange(self.n, 0, -1, dtype=self.dtype, device=self.device))
self.Afun = self.A.matmul
self.Arfun = self.A.T.matmul
Expand Down Expand Up @@ -228,7 +228,7 @@ def setUp(self) -> None:
super().setUp()


def lowerBidiagonalMatrix(m, n):
def lowerBidiagonalMatrix(m, n, dtype, device):
# This is a simple example for testing LSMR.
# It uses the leading m*n submatrix from
# A = [ 1
Expand All @@ -241,23 +241,29 @@ def lowerBidiagonalMatrix(m, n):
#
# 04 Jun 2010: First version for distribution with lsmr.py
if m <= n:
row = torch.hstack((torch.arange(m, dtype=int), torch.arange(1, m, dtype=int)))
col = torch.hstack((torch.arange(m, dtype=int), torch.arange(m - 1, dtype=int)))
row = torch.hstack((torch.arange(m, dtype=int, device=device), torch.arange(1, m, dtype=int, device=device)))
col = torch.hstack((torch.arange(m, dtype=int, device=device), torch.arange(m - 1, dtype=int, device=device)))
idx = torch.vstack((row, col))
data = torch.hstack((torch.arange(1, m + 1, dtype=float), torch.arange(1, m, dtype=float)))
data = torch.hstack(
(torch.arange(1, m + 1, dtype=dtype, device=device), torch.arange(1, m, dtype=dtype, device=device))
)
return torch.sparse_coo_tensor(idx, data, size=(m, n))
else:
row = torch.hstack((torch.arange(n, dtype=int), torch.arange(1, n + 1, dtype=int)))
col = torch.hstack((torch.arange(n, dtype=int), torch.arange(n, dtype=int)))
row = torch.hstack(
(torch.arange(n, dtype=int, device=device), torch.arange(1, n + 1, dtype=int, device=device))
)
col = torch.hstack((torch.arange(n, dtype=int, device=device), torch.arange(n, dtype=int, device=device)))
idx = torch.vstack((row, col))
data = torch.hstack((torch.arange(1, n + 1, dtype=float), torch.arange(1, n + 1, dtype=float)))
data = torch.hstack(
(torch.arange(1, n + 1, dtype=dtype, device=device), torch.arange(1, n + 1, dtype=dtype, device=device))
)
return torch.sparse_coo_tensor(idx, data, size=(m, n))


def lsmrtest(m, n, damp, dtype, device):
"""Verbose testing of lsmr"""

A = lowerBidiagonalMatrix(m, n)
A = lowerBidiagonalMatrix(m, n, dtype, device)
xtrue = torch.arange(n, 0, -1, dtype=dtype, device=device)
Afun = A.matmul

Expand Down

0 comments on commit bca6b2c

Please sign in to comment.