Skip to content

Commit

Permalink
fix autoblock issue (#223)
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Apr 9, 2024
1 parent 197c2e1 commit a2c6d27
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 31 deletions.
63 changes: 33 additions & 30 deletions quimb/linalg/autoblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,7 @@ def get_nz(A): # pragma: no cover
return np.nonzero(A)


@njit(
[
"void(int32, int32, List(Set(int32)))",
"void(int64, int64, List(Set(int64)))",
]
)
def _add_to_groups(i, j, groups): # pragma: no cover
for group in groups:
if i in group:
group.add(j)
return
if j in group:
group.add(i)
return
# pair is not in a sector yet - create new one
groups.append({i, j})


@njit(
[
"List(List(int32))(int32[:], int32[:], int_)",
"List(List(int64))(int64[:], int64[:], int_)",
]
)
@njit("List(List(int64))(int64[:], int64[:], int_)")
def compute_blocks(ix, jx, d): # pragma: no cover
"""Find the charge sectors (blocks in matrix terms) given element
coordinates ``ix`` and ``jx`` and total size ``d``.
Expand Down Expand Up @@ -67,18 +44,44 @@ def compute_blocks(ix, jx, d): # pragma: no cover
>>> sectors
[[0], [1, 2, 4, 8], [3, 5, 6, 9, 10, 12], [7, 11, 13, 14], [15]]
"""
groups = [{ix[0], jx[0]}]
groups = []

# go through actual nz
# go through actual nz -> these define edges of a graph and we are
# looking for all connected components (disconnected subgraphs)
for i, j in zip(ix, jx):
_add_to_groups(i, j, groups)
merge = []
for g, group in enumerate(groups):
if i in group:
group.add(j)
merge.append(g)
elif j in group:
group.add(i)
merge.append(g)

if len(merge) == 0:
# new group
groups.append({i, j})

elif len(merge) > 1:
# merge groups
group0 = groups[merge[0]]
for g in merge[-1:0:-1]:
# XXX: just popping here causes numba big problems?
# so we clear and filter empty groups later
other_group = groups[g]
group0.update(other_group)
other_group.clear()

# make sure kernel added as subspace
for i in range(d):
_add_to_groups(i, i, groups)
for group in groups:
if i in group:
break
else: # no break
groups.append({i})

# sort indices in each group and groups by first element
return sorted([sorted(g) for g in groups])
return sorted([sorted(g) for g in groups if g])


@pnjit
Expand Down Expand Up @@ -170,7 +173,7 @@ def _eigh_autoblocked(A, sort=True): # pragma: no cover
gs = [np.array(g) for g in gs]

# diagonalize each charge sector seperately
for i, g in enumerate(gs):
for g in gs:
ng = len(g)

# check if trivial
Expand Down
5 changes: 4 additions & 1 deletion tests/test_accel.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,10 @@ def test_dot_sparse_dense(self, mat_s, ket_d):
assert isinstance(cq, qarray)
cq = mat_s @ ket_d
assert isinstance(cq, qarray)
cn = mat_s._mul_vector(ket_d)
try:
cn = mat_s._matmul_vector(ket_d)
except AttributeError:
cn = mat_s._mul_vector(ket_d)
assert not issparse(cq)
assert isdense(cq)
assert_allclose(cq.A.ravel(), cn)
Expand Down

0 comments on commit a2c6d27

Please sign in to comment.