Skip to content

Commit

Permalink
svd_truncated: restore scipy fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Feb 8, 2024
1 parent 161b6ca commit b507abc
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
5 changes: 5 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ Release notes for `quimb`.
- [qu.randn](quimb.randn): support `dist="rademacher"`.
- support `dist` and other `randn` options in various TN builders.

**Bug fixes:**

- restore fallback (to `scipy.linalg.svd` with driver='gesvd') behavior for
truncated SVD with numpy backend.


(whats-new-1-7-2)=
## v1.7.2 (2024-01-30)
Expand Down
24 changes: 22 additions & 2 deletions quimb/tensor/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

import functools
import operator
import warnings

import numpy as np
import scipy.sparse.linalg as spla
import scipy.linalg as scla
import scipy.linalg.interpolative as sli
import scipy.sparse.linalg as spla
from autoray import (
astype,
backend_like,
Expand Down Expand Up @@ -313,7 +315,6 @@ def _trim_and_renorm_svd_result_numba(
return U, None, VH


@svd_truncated.register("numpy")
@njit # pragma: no cover
def svd_truncated_numba(
x, cutoff=-1.0, cutoff_mode=4, max_bond=-1, absorb=0, renorm=0
Expand All @@ -325,6 +326,25 @@ def svd_truncated_numba(
)


@svd_truncated.register("numpy")
def svd_truncated_numpy(
x, cutoff=-1.0, cutoff_mode=4, max_bond=-1, absorb=0, renorm=0
):
"""Numpy version of ``svd_truncated``, trying the accelerated version
first, then falling back to the more stable scipy version.
"""
try:
return svd_truncated_numba(
x, cutoff, cutoff_mode, max_bond, absorb, renorm
)
except np.linalg.LinAlgError as e: # pragma: no cover
warnings.warn(f"Got: {e}, falling back to scipy gesvd driver.")
U, s, VH = scla.svd(x, full_matrices=False, lapack_driver="gesvd")
return _trim_and_renorm_svd_result_numba(
U, s, VH, cutoff, cutoff_mode, max_bond, absorb, renorm
)


@svd_truncated.register("autoray.lazy")
@lazy.core.lazy_cache("svd_truncated")
def svd_truncated_lazy(
Expand Down

0 comments on commit b507abc

Please sign in to comment.