Skip to content

Commit

Permalink
update ytools.mattype: add option to output cholesky decomp and bette…
Browse files Browse the repository at this point in the history
…r type checking

update ytools.eig_si to request cholesky decomp from mattype (faster)
  • Loading branch information
twmacro committed Oct 1, 2023
1 parent 8e13af7 commit 8e621e1
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 45 deletions.
32 changes: 31 additions & 1 deletion pyyeti/tests/test_ytools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import tempfile
import os
from pyyeti import ytools
import scipy.linalg as linalg
from scipy import linalg
import pytest


Expand Down Expand Up @@ -130,6 +130,36 @@ def test_mattype():
ytools.mattype(c, "badtype")


def test_mattype2():
# real a
a = np.array([[346500.0, 1e-7], [2.1e-7, 1000000.1]])
assert 13 == ytools.mattype(a)[0]
t, mattypes, ch = ytools.mattype(a, return_cholesky=True)
assert t == 13
assert np.allclose(ch, linalg.cholesky(a))

# complex a
a = np.array([[346500.0, 1e-7 * (1 + 1j)], [2.1e-7 * (1 - 1j), 1000000.1]])
assert 15 == ytools.mattype(a)[0]
t, mattypes, ch = ytools.mattype(a, return_cholesky=True)
assert t == 15
assert np.allclose(ch, linalg.cholesky(a))

# real a
a = np.array([[346500.0, 1e-7], [2000.1e-7, 1000000.1]])
assert 0 == ytools.mattype(a)[0]
t, mattypes, ch = ytools.mattype(a, return_cholesky=True)
assert t == 0
assert ch is None

# complex a
a = np.array([[346500.0, 1e-7 * (1 + 1j)], [2000.1e-7 * (1 - 1j), 1000000.1]])
assert 0 == ytools.mattype(a)[0]
t, mattypes, ch = ytools.mattype(a, return_cholesky=True)
assert t == 0
assert ch is None


def test_save_load():
a = np.arange(18).reshape(2, 3, 3)
b = np.arange(3)
Expand Down
116 changes: 72 additions & 44 deletions pyyeti/ytools.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,40 @@ def isdiag(A, tol=1e-12):
return max_off <= tol * max_on


def mattype(A, mtype=None):
def _check_symm_herm(A, mattypes):
Atype = 0
if np.allclose(A, A.T):
Atype |= mattypes["symmetric"]

elif np.iscomplexobj(A) and np.allclose(A, A.T.conj()):
Atype |= mattypes["hermitian"]

if isdiag(A):
Atype |= mattypes["diagonal"] | mattypes["symmetric"]
if np.iscomplexobj(A):
Atype |= mattypes["hermitian"]
d = np.diag(A)
if np.allclose(1, d):
Atype |= mattypes["identity"]

return Atype


def _check_cholesky(Atype, A, mattypes):
chol = None
if (Atype & mattypes["symmetric"] and np.isrealobj(A)) or (
Atype & mattypes["hermitian"]
):
try:
chol = linalg.cholesky(A)
except linalg.LinAlgError:
pass
else:
Atype |= mattypes["posdef"]
return Atype, chol


def mattype(A, mtype=None, return_cholesky=False):
"""
Checks contents of square matrix `A` to see if it is symmetric,
hermitian, positive-definite, diagonal, and identity.
Expand All @@ -731,6 +764,10 @@ def mattype(A, mtype=None):
this case, True is returned if `A` is of the type specified or
False otherwise. If None, `Atype` (if `A` is not None) and
`mattypes` is returned. `mtype` is ignored if `A` is None.
return_cholesky : bool; optional
If True, the output of :func:`scipy.linalg.cholesky` is
returned if computed. Output will be None if `A` is not
positive-definite. See example usages below.
Returns
-------
Expand All @@ -753,13 +790,25 @@ def mattype(A, mtype=None):
Not returned if `mtype` is specified. This is the only return
if `A` is None.
chol : 2d ndarray or None
See `return_cholesky` above. If returned, `chol` will be the
output of :func:`scipy.linalg.cholesky` (with default
settings) or None, depending on whether matrix is
positive-definite or not.
Notes
-----
Here are some example usages::
mattype(A) # returns (Atype, mattypes)
mattype(A, 'symmetric') # returns True or False
mattype(None) # returns mattypes
Here are some example usages:
========================================== =======================
Usage Returns
========================================== =======================
mattype(A) (Atype, mattypes)
mattype(A, return_cholesky=True) (Atype, mattypes, chol)
mattype(A, 'symmetric') True or False
mattype(A, 'posdef', return_cholesky=True) (True or False, chol)
mattype(None) mattypes
========================================== =======================
See also
--------
Expand Down Expand Up @@ -809,51 +858,30 @@ def mattype(A, mtype=None):
if mtype is None:
if A.ndim != 2 or A.shape[0] != A.shape[1]:
return Atype, mattypes
if np.allclose(A, A.T):
Atype |= mattypes["symmetric"]
if np.isrealobj(A):
try:
linalg.cholesky(A)
except linalg.LinAlgError:
pass
else:
Atype |= mattypes["posdef"]
elif np.iscomplexobj(A) and np.allclose(A, A.T.conj()):
Atype |= mattypes["hermitian"]
try:
linalg.cholesky(A)
except linalg.LinAlgError:
pass
else:
Atype |= mattypes["posdef"]
if isdiag(A):
Atype |= mattypes["diagonal"]
d = np.diag(A)
if np.allclose(1, d):
Atype |= mattypes["identity"]

Atype = _check_symm_herm(A, mattypes)
Atype, chol = _check_cholesky(Atype, A, mattypes)

if return_cholesky:
return Atype, mattypes, chol
return Atype, mattypes

if A.ndim != 2 or A.shape[0] != A.shape[1]:
return False

if mtype == "symmetric":
return np.allclose(A, A.T)
return np.allclose(A, A.T) or isdiag(A)

if mtype == "hermitian":
return np.allclose(A, A.T.conj())
return np.iscomplexobj(A) and (np.allclose(A, A.T.conj()) or isdiag(A))

if mtype == "posdef":
if np.isrealobj(A):
if not np.allclose(A, A.T):
return False
else:
if not np.allclose(A, A.T.conj()):
return False
try:
linalg.cholesky(A)
return True
except linalg.LinAlgError:
return False
Atype = _check_symm_herm(A, mattypes)
Atype, chol = _check_cholesky(Atype, A, mattypes)
ret = bool(Atype & mattypes["posdef"])
if return_cholesky:
return ret, chol
return ret

if mtype in ("diagonal", "identity"):
if isdiag(A):
Expand Down Expand Up @@ -1148,20 +1176,20 @@ def eig_si(
# Mk = (Mk + Mk.T) / 2

# solve subspace eigenvalue problem:
mtp = mattype(Mk, "posdef")
mtp, Mkuu = mattype(Mk, "posdef", return_cholesky=True)
if not mtp:
factor = 1000 * eps
pc = 0
while 1:
pc += 1
Mk += np.diag(np.diag(Mk) * factor)
factor *= 10.0
mtp = mattype(Mk, "posdef")
mtp, Mkuu = mattype(Mk, "posdef", return_cholesky=True)
if mtp or pc > 5:
break

if mtp:
Mkll = linalg.cholesky(Mk, lower=False).T
Mkll = Mkuu.T # linalg.cholesky(Mk, lower=False).T
Kkmod = linalg.solve_triangular(
Mkll, linalg.solve_triangular(Mkll, Kk, lower=True).T, lower=True
)
Expand Down

0 comments on commit 8e621e1

Please sign in to comment.