Skip to content

Commit

Permalink
Ported fix from cornellius-gp/linear_operator@9b58b4a - Fix normaliza…
Browse files Browse the repository at this point in the history
…tion of initial guess in linear_cg()
  • Loading branch information
tvercaut committed Oct 24, 2023
1 parent a9fd056 commit 194ce5c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
20 changes: 20 additions & 0 deletions torchsparsegradutils/tests/test_linear_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,26 @@ def test_batch_cg_with_tridiag(self):
approx_eigs = torch.linalg.eigvalsh(t_mats[j, i])
self.assertTrue(torch.allclose(eigs, approx_eigs, atol=1e-3, rtol=1e-4))

def test_batch_cg_init(self):
batch = 5
size = 100
matrix = torch.randn(batch, size, size, dtype=torch.float64)
matrix = matrix.matmul(matrix.mT)
matrix.div_(matrix.norm())
matrix.add_(torch.eye(matrix.size(-1), dtype=torch.float64).mul_(1e-1))

# Initial solve
rhs = torch.randn(batch, size, 50, dtype=torch.float64)
solves = linear_cg(matrix.matmul, rhs=rhs, max_iter=size, max_tridiag_iter=0)

# Initialize with solve
solves_with_init = linear_cg(matrix.matmul, rhs=rhs, max_iter=1, initial_guess=solves, max_tridiag_iter=0)

# Check cg
matrix_chol = torch.linalg.cholesky(matrix)
actual = torch.cholesky_solve(rhs, matrix_chol)
self.assertTrue(torch.allclose(solves_with_init, actual, atol=1e-3, rtol=1e-4))


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions torchsparsegradutils/utils/linear_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def linear_cg(

# Let's normalize. We'll un-normalize afterwards
rhs = rhs.div(rhs_norm)
initial_guess = initial_guess.div(rhs_norm)

# residual: residual_{0} = b_vec - lhs x_{0}
residual = rhs - matmul_closure(initial_guess)
Expand Down

0 comments on commit 194ce5c

Please sign in to comment.