Skip to content

Commit

Permalink
PBC support for 2D and 3D HOTRG, CTMRG
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Mar 7, 2024
1 parent 6e522e6 commit 5ba7724
Show file tree
Hide file tree
Showing 4 changed files with 376 additions and 115 deletions.
125 changes: 91 additions & 34 deletions quimb/tensor/tensor_2d.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Classes and algorithms related to 2D tensor networks.
"""
"""Classes and algorithms related to 2D tensor networks."""

import re
import random
import functools
Expand Down Expand Up @@ -256,7 +256,35 @@ def __init__(self, tn, xrange, yrange, from_which, stepsize=1):
self.sweep = range(self.imax, self.imin - 1, -stepsize)
self.istep = -stepsize

self.sweep_other = range(self.jmin, self.jmax + 1)
@functools.cached_property
def sweep_other(self):
return range(self.jmin, self.jmax + 1)

@functools.cached_property
def cyclic_x(self):
return self.is_cyclic_x(
(self.jmin + self.jmax) // 2,
self.imin,
self.imax,
)

@functools.cached_property
def cyclic_y(self):
return self.is_cyclic_y(
(self.imin + self.imax) // 2,
self.jmin,
self.jmax,
)

def get_jnext(self, j):
if j == self.jmax:
if self.cyclic_y:
# wrap around
return self.jmin
# no more steps
return None
# normal step
return j + 1

def get_opposite_env_fn(self):
"""Get the function and location label for contracting boundaries in
Expand Down Expand Up @@ -682,14 +710,21 @@ def is_cyclic_x(self, j=None, imin=None, imax=None):
"""Check if the x dimension is cyclic (periodic), specifically whether
a bond exists between ``(imin, j)`` and ``(imax, j)``, with default
values of ``imin = 0`` and ``imax = Lx - 1``, and ``j`` at the center
of the lattice.
of the lattice. If ``imin`` and ``imax`` are adjacent then this is
considered False, since there is no 'extra' connectivity.
"""
if j is None:
j = self.Ly // 2
if imin is None:
imin = 0
if imax is None:
imax = self.Lx - 1

if abs(imax - imin) <= 1:
# first and last sites already connected -> a bit undefined
return False

if j is None:
j = self.Ly // 2

return bool(
bonds(
self[self.site_tag(imin, j)],
Expand All @@ -701,14 +736,21 @@ def is_cyclic_y(self, i=None, jmin=None, jmax=None):
"""Check if the y dimension is cyclic (periodic), specifically whether
a bond exists between ``(i, jmin)`` and ``(i, jmax)``, with default
values of ``jmin = 0`` and ``jmax = Ly - 1``, and ``i`` at the center
of the lattice.
of the lattice. If ``jmin`` and ``jmax`` are adjacent then this is
considered False, since there is no 'extra' connectivity.
"""
if i is None:
i = self.Lx // 2
if jmin is None:
jmin = 0
if jmax is None:
jmax = self.Ly - 1

if abs(jmax - jmin) <= 1:
# first and last sites already connected -> a bit undefined
return False

if i is None:
i = self.Lx // 2

return bool(
bonds(
self[self.site_tag(i, jmin)],
Expand Down Expand Up @@ -1530,20 +1572,17 @@ def _contract_boundary_projector(
compress_opts = ensure_dict(compress_opts)

r = Rotator2D(self, xrange, yrange, from_which)
cyclic_y = r.is_cyclic_y()

for i0, i1 in pairwise(r.sweep):
for i, inext in pairwise(r.sweep):
# we compute the projectors from an untouched copy
tn_calc = self.copy()

for j in r.sweep_other:
tag_ij = r.site_tag(i0, j)
tag_ip1j = r.site_tag(i1, j)

if (j < r.jmax) or cyclic_y:
ltags = (tag_ij, tag_ip1j)
jp1 = j + 1 if j < r.jmax else r.jmin
rtags = r.site_tag(i0, jp1), r.site_tag(i1, jp1)
# this handles cyclic boundary conditions
jnext = r.get_jnext(j)
if jnext is not None:
ltags = (r.site_tag(i, j), r.site_tag(inext, j))
rtags = (r.site_tag(i, jnext), r.site_tag(inext, jnext))
# │ │
# ──O─┐ chi ┌─O── i+1
# │ └─▷═◁─┘ │
Expand All @@ -1565,12 +1604,12 @@ def _contract_boundary_projector(
# contract each pair of boundary tensors with their projectors
for j in r.sweep_other:
self.contract_tags_(
(r.site_tag(i0, j), r.site_tag(i1, j)),
(r.site_tag(i, j), r.site_tag(inext, j)),
optimize=optimize,
)

if equalize_norms:
for t in self.select_tensors(r.x_tag(i1)):
for t in self.select_tensors(r.x_tag(inext)):
self.strip_exponent(t, equalize_norms)

def contract_boundary_from(
Expand Down Expand Up @@ -2205,9 +2244,16 @@ def _contract_interleaved_boundary_sequence(
if sequence is None:
# contract in both sides along short dimension -> less compression
if self.Lx >= self.Ly:
sequence = ("xmin", "xmax")

if self.is_cyclic_x():
sequence = ("xmin",)
else:
sequence = ("xmin", "xmax")
else:
sequence = ("ymin", "ymax")
if self.is_cyclic_y():
sequence = ("ymin",)
else:
sequence = ("ymin", "ymax")
else:
sequence = parse_boundary_sequence(sequence)

Expand Down Expand Up @@ -3251,16 +3297,19 @@ def coarse_grain_hotrg(
tn = self if inplace else self.copy()
tn_calc = tn.copy()

r = Rotator2D(tn, None, None, direction + "min")
cyclic_y = r.is_cyclic_y()
r = Rotator2D(tn, None, None, f"{direction}min")

# track new coordinates / tags
retag_map = {}

for i in range(r.imin, r.imax + 1, 2):
next_i_in_lattice = i + 1 <= r.imax
inext = i + 1
next_i_in_lattice = inext <= r.imax

for j in r.sweep_other:
# handles cyclic case
jnext = r.get_jnext(j)

for j in range(r.jmin, r.jmax + 1):
# │ │
# ──O─┐ chi ┌─O── i+1
# │ └─▷═◁─┘ │
Expand All @@ -3269,18 +3318,15 @@ def coarse_grain_hotrg(
# │ │
# j j+1
tag_ij = r.site_tag(i, j)
tag_ip1j = r.site_tag(i + 1, j)
tag_ip1j = r.site_tag(inext, j)
new_tag = r.site_tag(i // 2, j)
retag_map[tag_ij] = new_tag
if next_i_in_lattice:
retag_map[tag_ip1j] = new_tag

if next_i_in_lattice and ((j + 1 <= r.jmax) or cyclic_y):
if next_i_in_lattice and jnext is not None:
ltags = (tag_ij, tag_ip1j)

# handle cyclic case
jp1 = j + 1 if (j + 1 <= r.jmax) else r.jmin
rtags = r.site_tag(i, jp1), r.site_tag(i + 1, jp1)
rtags = r.site_tag(i, jnext), r.site_tag(inext, jnext)
tn_calc.insert_compressor_between_regions(
ltags,
rtags,
Expand All @@ -3294,7 +3340,7 @@ def coarse_grain_hotrg(

retag_map[r.x_tag(i)] = r.x_tag(i // 2)
if next_i_in_lattice:
retag_map[r.x_tag(i + 1)] = r.x_tag(i // 2)
retag_map[r.x_tag(inext)] = r.x_tag(i // 2)

# then we retag the tensor network and adjust its size
tn.retag_(retag_map)
Expand Down Expand Up @@ -3475,7 +3521,7 @@ def contract_ctmrg(
lazy=False,
mode="projector",
compress_opts=None,
sequence=("xmin", "xmax", "ymin", "ymax"),
sequence=None,
xmin=None,
xmax=None,
ymin=None,
Expand Down Expand Up @@ -3573,6 +3619,17 @@ def contract_ctmrg(
# we are implicitly asking for the tensor network
final_contract = False

if sequence is None:
sequence = []
if self.is_cyclic_x():
sequence.append("xmin")
else:
sequence.extend(["xmin", "xmax"])
if self.is_cyclic_y():
sequence.append("ymin")
else:
sequence.extend(["ymin", "ymax"])

return self._contract_interleaved_boundary_sequence(
contract_boundary_opts=contract_boundary_opts,
canonize=canonize,
Expand Down
Loading

0 comments on commit 5ba7724

Please sign in to comment.