Skip to content

Commit

Permalink
update decomp.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Oct 17, 2023
1 parent d3cdd9e commit 270c810
Showing 1 changed file with 42 additions and 19 deletions.
61 changes: 42 additions & 19 deletions quimb/tensor/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,21 @@
from ..linalg import rand_linalg


_CUTOFF_MODE_MAP = {
"abs": 1,
"rel": 2,
"sum2": 3,
"rsum2": 4,
"sum1": 5,
"rsum1": 6,
}


def map_cutoff_mode(cutoff_mode):
"""Map mode to an integer for compatibility with numba."""
return _CUTOFF_MODE_MAP.get(cutoff_mode, cutoff_mode)


# some convenience functions for multiplying diagonals


Expand Down Expand Up @@ -82,14 +97,14 @@ def sgn(x):
"""Get the 'sign' of ``x``, such that ``x / sgn(x)`` is real and
non-negative.
"""
x0 = (x == 0.0)
x0 = x == 0.0
return (x + x0) / (do("abs", x) + x0)


@sgn.register("numpy")
@njit # pragma: no cover
def sgn_numba(x):
x0 = (x == 0.0)
x0 = x == 0.0
return (x + x0) / (np.abs(x) + x0)


Expand Down Expand Up @@ -176,10 +191,10 @@ def svd_truncated(
Parameters
----------
cutoff : float
cutoff : float, optional
Singular value cutoff threshold, if ``cutoff <= 0.0``, then only
``max_bond`` is used.
cutoff_mode : {1, 2, 3, 4, 5, 6}
cutoff_mode : {1, 2, 3, 4, 5, 6}, optional
How to perform the trim:
- 1: ['abs'], trim values below ``cutoff``
Expand All @@ -189,12 +204,12 @@ def svd_truncated(
- 5: ['sum1'], trim s.t. ``sum(s_trim**1) < cutoff``.
- 6: ['rsum1'], trim s.t. ``sum(s_trim**1) < sum(s**1) * cutoff``.
max_bond : int
max_bond : int, optional
An explicit maximum bond dimension, use -1 for none.
absorb : {-1, 0, 1, None}
absorb : {-1, 0, 1, None}, optional
How to absorb the singular values. -1: left, 0: both, 1: right and
None: don't absorb (return).
renorm : {0, 1}
renorm : {0, 1}, optional
Whether to renormalize the singular values (depends on `cutoff_mode`).
"""
with backend_like(backend):
Expand Down Expand Up @@ -313,7 +328,12 @@ def svd_truncated_numba(
@svd_truncated.register("autoray.lazy")
@lazy.core.lazy_cache("svd_truncated")
def svd_truncated_lazy(
x, cutoff=-1.0, cutoff_mode=4, max_bond=-1, absorb=0, renorm=0,
x,
cutoff=-1.0,
cutoff_mode=4,
max_bond=-1,
absorb=0,
renorm=0,
):
if cutoff != 0.0:
raise ValueError("Can't handle dynamic cutoffs in lazy mode.")
Expand All @@ -326,7 +346,7 @@ def svd_truncated_lazy(
lsvdt = x.to(
fn=get_lib_fn(x.backend, "svd_truncated"),
args=(x, cutoff, cutoff_mode, max_bond, absorb, renorm),
shape=(3,)
shape=(3,),
)

U = lsvdt.to(operator.getitem, (lsvdt, 0), shape=(m, k))
Expand Down Expand Up @@ -364,14 +384,14 @@ def lu_truncated(
)

with backend_like(backend):
PL, U = do('scipy.linalg.lu', x, permute_l=True)
PL, U = do("scipy.linalg.lu", x, permute_l=True)

sl = do('sum', do('abs', PL), axis=0)
su = do('sum', do('abs', U), axis=1)
sl = do("sum", do("abs", PL), axis=0)
su = do("sum", do("abs", U), axis=1)

if cutoff_mode == 2:
abs_cutoff_l = cutoff * do('max', sl)
abs_cutoff_u = cutoff * do('max', su)
abs_cutoff_l = cutoff * do("max", sl)
abs_cutoff_u = cutoff * do("max", su)
elif cutoff_mode == 1:
abs_cutoff_l = abs_cutoff_u = cutoff
else:
Expand Down Expand Up @@ -943,13 +963,13 @@ def isometrize_cayley(x, backend):
"pad", x, [[0, d - m], [0, d - n]], "constant", constant_values=0.0
)
x = x - dag(x)
x = x / 2.
x = x / 2.0
if backend == "torch":
# XXX: move device handling upstream in to autoray?
Id = do("eye", d, like=x, device=x.device)
else:
Id = do("eye", d, like=x)
Q = do('linalg.solve', Id - x, Id + x)
Q = do("linalg.solve", Id - x, Id + x)
return Q[:m, :n]


Expand All @@ -974,7 +994,7 @@ def isometrize_modified_gram_schmidt(A, backend=None):
def isometrize_householder(X, backend=None):
with backend_like(backend):
X = do("tril", X, -1)
tau = 2. / (1. + do("sum", do("conj", X) * X, 0))
tau = 2.0 / (1.0 + do("sum", do("conj", X) * X, 0))
Q = do("linalg.householder_product", X, tau)
return Q

Expand Down Expand Up @@ -1125,7 +1145,7 @@ def squared_op_to_reduced_factor_numba(x2, dl, dr, right=True):


def compute_oblique_projectors(
Rl, Rr, max_bond, cutoff, absorb="both", **compress_opts
Rl, Rr, max_bond, cutoff, absorb="both", cutoff_mode=4, **compress_opts
):
"""Compute the oblique projectors for two reduced factor matrices that
describe a gauge on a bond. Concretely, assuming that ``Rl`` and ``Rr`` are
Expand Down Expand Up @@ -1162,12 +1182,15 @@ def compute_oblique_projectors(
if max_bond is None:
max_bond = -1

cutoff_mode = map_cutoff_mode(cutoff_mode)

Ut, st, VHt = svd_truncated(
Rl @ Rr,
max_bond=max_bond,
cutoff=cutoff,
absorb=None,
**compress_opts
cutoff_mode=cutoff_mode,
**compress_opts,
)
st_sqrt = do("sqrt", st)

Expand Down

0 comments on commit 270c810

Please sign in to comment.