diff --git a/quimb/tensor/fitting.py b/quimb/tensor/fitting.py new file mode 100644 index 00000000..b11e6539 --- /dev/null +++ b/quimb/tensor/fitting.py @@ -0,0 +1,385 @@ +"""Tools for computing distances between and fitting tensor networks.""" +from autoray import dag, do + +from .contraction import contract_strategy +from ..utils import check_opt + + +def tensor_network_distance( + tnA, + tnB, + xAA=None, + xAB=None, + xBB=None, + method="auto", + normalized=False, + **contract_opts, +): + r"""Compute the Frobenius norm distance between two tensor networks: + + .. math:: + + D(A, B) + = | A - B |_{\mathrm{fro}} + = \mathrm{Tr} [(A - B)^{\dagger}(A - B)]^{1/2} + = ( \langle A | A \rangle - 2 \mathrm{Re} \langle A | B \rangle| + + \langle B | B \rangle ) ^{1/2} + + which should have matching outer indices. Note the default approach to + computing the norm is precision limited to about ``eps**0.5`` where ``eps`` + is the precision of the data type, e.g. ``1e-8`` for float64. This is due + to the subtraction in the above expression. + + Parameters + ---------- + tnA : TensorNetwork or Tensor + The first tensor network operator. + tnB : TensorNetwork or Tensor + The second tensor network operator. + xAA : None or scalar + The value of ``A.H @ A`` if you already know it (or it doesn't matter). + xAB : None or scalar + The value of ``A.H @ B`` if you already know it (or it doesn't matter). + xBB : None or scalar + The value of ``B.H @ B`` if you already know it (or it doesn't matter). + method : {'auto', 'overlap', 'dense'}, optional + How to compute the distance. If ``'overlap'``, the default, the + distance will be computed as the sum of overlaps, without explicitly + forming the dense operators. If ``'dense'``, the operators will be + directly formed and the norm computed, which can be quicker when the + exterior dimensions are small. If ``'auto'``, the dense method will + be used if the total operator (outer) size is ``<= 2**16``. + normalized : bool, optional + If ``True``, then normalize the distance by the norm of the two + operators, i.e. ``2 * D(A, B) / (|A| + |B|)``. The resulting distance + lies between 0 and 2 and is more useful for assessing convergence. + contract_opts + Supplied to :meth:`~quimb.tensor.tensor_core.TensorNetwork.contract`. + + Returns + ------- + D : float + """ + check_opt("method", method, ("auto", "dense", "overlap")) + + tnA = tnA.as_network() + tnB = tnB.as_network() + + oix = tnA.outer_inds() + if set(oix) != set(tnB.outer_inds()): + raise ValueError( + "Can only compute distance between tensor " + "networks with matching outer indices." + ) + + if method == "auto": + d = tnA.inds_size(oix) + if d <= 1 << 16: + method = "dense" + else: + method = "overlap" + + # directly from vectorizations of both + if method == "dense": + tnA = tnA.contract(..., output_inds=oix, preserve_tensor=True) + tnB = tnB.contract(..., output_inds=oix, preserve_tensor=True) + + # overlap method + if xAA is None: + xAA = (tnA | tnA.H).contract(..., **contract_opts) + if xAB is None: + xAB = (tnA | tnB.H).contract(..., **contract_opts) + if xBB is None: + xBB = (tnB | tnB.H).contract(..., **contract_opts) + + dAB = do("abs", xAA - 2 * do("real", xAB) + xBB) ** 0.5 + + if normalized: + dAB *= 2 / (do("abs", xAA)**0.5 + do("abs", xBB)**0.5) + + return dAB + + + + +def tensor_network_fit_autodiff( + tn, + tn_target, + steps=1000, + tol=1e-9, + autodiff_backend="autograd", + contract_optimize="auto-hq", + distance_method="auto", + inplace=False, + progbar=False, + **kwargs, +): + """Optimize the fit of ``tn`` with respect to ``tn_target`` using + automatic differentation. This minimizes the norm of the difference + between the two tensor networks, which must have matching outer indices, + using overlaps. + + Parameters + ---------- + tn : TensorNetwork + The tensor network to fit. + tn_target : TensorNetwork + The target tensor network to fit ``tn`` to. + steps : int, optional + The maximum number of autodiff steps. + tol : float, optional + The target norm distance. + autodiff_backend : str, optional + Which backend library to use to perform the gradient computation. + contract_optimize : str, optional + The contraction path optimized used to contract the overlaps. + distance_method : {'auto', 'dense', 'overlap'}, optional + Supplied to :func:`~quimb.tensor.tensor_core.tensor_network_distance`, + controls how the distance is computed. + inplace : bool, optional + Update ``tn`` in place. + progbar : bool, optional + Show a live progress bar of the fitting process. + kwargs + Passed to :class:`~quimb.tensor.tensor_core.optimize.TNOptimizer`. + + See Also + -------- + tensor_network_distance, tensor_network_fit_als + """ + from .optimize import TNOptimizer + from .tensor_core import tensor_network_distance + + xBB = (tn_target | tn_target.H).contract( + ..., + output_inds=(), + optimize=contract_optimize, + ) + + tnopt = TNOptimizer( + tn=tn, + loss_fn=tensor_network_distance, + loss_constants={"tnB": tn_target, "xBB": xBB}, + loss_kwargs={"method": distance_method, "optimize": contract_optimize}, + autodiff_backend=autodiff_backend, + progbar=progbar, + **kwargs, + ) + + tn_fit = tnopt.optimize(steps, tol=tol) + + if not inplace: + return tn_fit + + for t1, t2 in zip(tn, tn_fit): + t1.modify(data=t2.data) + + return tn + + +def _tn_fit_als_core( + var_tags, + tnAA, + tnAB, + xBB, + tol, + contract_optimize, + steps, + enforce_pos, + pos_smudge, + solver="solve", + progbar=False, +): + from .tensor_core import group_inds + + # shared intermediates + greedy = good reuse of contractions + with contract_strategy(contract_optimize): + # prepare each of the contractions we are going to repeat + env_contractions = [] + for tg in var_tags: + # varying tensor and conjugate in norm + tk = tnAA["__KET__", tg] + tb = tnAA["__BRA__", tg] + + # get inds, and ensure any bonds come last, for linalg.solve + lix, bix, rix = group_inds(tb, tk) + tk.transpose_(*rix, *bix) + tb.transpose_(*lix, *bix) + + # form TNs with 'holes', i.e. environment tensors networks + A_tn = tnAA.select((tg,), "!all") + y_tn = tnAB.select((tg,), "!all") + + env_contractions.append((tk, tb, lix, bix, rix, A_tn, y_tn)) + + if tol != 0.0: + old_d = float("inf") + + if progbar: + import tqdm + + pbar = tqdm.trange(steps) + else: + pbar = range(steps) + + # the main iterative sweep on each tensor, locally optimizing + for _ in pbar: + for tk, tb, lix, bix, rix, A_tn, y_tn in env_contractions: + Ni = A_tn.to_dense(lix, rix) + Wi = y_tn.to_dense(rix, bix) + + if enforce_pos: + el, ev = do("linalg.eigh", Ni) + el = do("clip", el, el[-1] * pos_smudge, None) + Ni_p = ev * do("reshape", el, (1, -1)) @ dag(ev) + else: + Ni_p = Ni + + if solver == "solve": + x = do("linalg.solve", Ni_p, Wi) + elif solver == "lstsq": + x = do("linalg.lstsq", Ni_p, Wi, rcond=pos_smudge)[0] + + x_r = do("reshape", x, tk.shape) + # n.b. because we are using virtual TNs -> updates propagate + tk.modify(data=x_r) + tb.modify(data=do("conj", x_r)) + + # assess | A - B | for convergence or printing + if (tol != 0.0) or progbar: + xAA = do("trace", dag(x) @ (Ni @ x)) # + xAB = do("trace", do("real", dag(x) @ Wi)) # + d = do("abs", (xAA - 2 * xAB + xBB)) ** 0.5 + if abs(d - old_d) < tol: + break + old_d = d + + if progbar: + pbar.set_description(str(d)) + + +def tensor_network_fit_als( + tn, + tn_target, + tags=None, + steps=100, + tol=1e-9, + solver="solve", + enforce_pos=False, + pos_smudge=None, + tnAA=None, + tnAB=None, + xBB=None, + contract_optimize="greedy", + inplace=False, + progbar=False, +): + """Optimize the fit of ``tn`` with respect to ``tn_target`` using + alternating least squares (ALS). This minimizes the norm of the difference + between the two tensor networks, which must have matching outer indices, + using overlaps. + + Parameters + ---------- + tn : TensorNetwork + The tensor network to fit. + tn_target : TensorNetwork + The target tensor network to fit ``tn`` to. + tags : sequence of str, optional + If supplied, only optimize tensors matching any of given tags. + steps : int, optional + The maximum number of ALS steps. + tol : float, optional + The target norm distance. + solver : {'solve', 'lstsq', ...}, optional + The underlying driver function used to solve the local minimization, + e.g. ``numpy.linalg.solve`` for ``'solve'`` with ``numpy`` backend. + enforce_pos : bool, optional + Whether to enforce positivity of the locally formed environments, + which can be more stable. + pos_smudge : float, optional + If enforcing positivity, the level below which to clip eigenvalues + for make the local environment positive definite. + tnAA : TensorNetwork, optional + If you have already formed the overlap ``tn.H & tn``, maybe + approximately, you can supply it here. The unconjugated layer should + have tag ``'__KET__'`` and the conjugated layer ``'__BRA__'``. Each + tensor being optimized should have tag ``'__VAR{i}__'``. + tnAB : TensorNetwork, optional + If you have already formed the overlap ``tn_target.H & tn``, maybe + approximately, you can supply it here. Each tensor being optimized + should have tag ``'__VAR{i}__'``. + xBB : float, optional + If you have already know, have computed ``tn_target.H @ tn_target``, + or it doesn't matter, you can supply the value here. + contract_optimize : str, optional + The contraction path optimized used to contract the local environments. + Note ``'greedy'`` is the default in order to maximize shared work. + inplace : bool, optional + Update ``tn`` in place. + progbar : bool, optional + Show a live progress bar of the fitting process. + + Returns + ------- + TensorNetwork + + See Also + -------- + tensor_network_fit_autodiff, tensor_network_distance + """ + # mark the tensors we are going to optimize + tna = tn.copy() + tna.add_tag("__KET__") + + if tags is None: + to_tag = tna + else: + to_tag = tna.select_tensors(tags, "any") + + var_tags = [] + for i, t in enumerate(to_tag): + var_tag = f"__VAR{i}__" + t.add_tag(var_tag) + var_tags.append(var_tag) + + # form the norm of the varying TN (A) and its overlap with the target (B) + if tnAA is None: + tnAA = tna | tna.H.retag_({"__KET__": "__BRA__"}) + if tnAB is None: + tnAB = tna | tn_target.H + + if (tol != 0.0) and (xBB is None): + # + xBB = (tn_target | tn_target.H).contract( + ..., + optimize=contract_optimize, + output_inds=(), + ) + + if pos_smudge is None: + pos_smudge = max(tol, 1e-15) + + _tn_fit_als_core( + var_tags=var_tags, + tnAA=tnAA, + tnAB=tnAB, + xBB=xBB, + tol=tol, + contract_optimize=contract_optimize, + steps=steps, + enforce_pos=enforce_pos, + pos_smudge=pos_smudge, + solver=solver, + progbar=progbar, + ) + + if not inplace: + tn = tn.copy() + + for t1, t2 in zip(tn, tna): + # transpose so only thing changed in original TN is data + t2.transpose_like_(t1) + t1.modify(data=t2.data) + + return tn diff --git a/quimb/tensor/tensor_2d_tebd.py b/quimb/tensor/tensor_2d_tebd.py index d24467f4..aff190d8 100644 --- a/quimb/tensor/tensor_2d_tebd.py +++ b/quimb/tensor/tensor_2d_tebd.py @@ -5,23 +5,24 @@ import numpy as np import scipy.sparse.linalg as spla -from autoray import do, dag, conj, reshape +from autoray import conj, dag, do, reshape -from ..utils import pairwise, default_to_neutral_style +from ..utils import default_to_neutral_style, pairwise +from .contraction import contract_strategy from .drawing import get_colors -from .tensor_core import Tensor, contract_strategy from .optimize import TNOptimizer from .tensor_2d import ( - gen_2d_bonds, - calc_plaquette_sizes, calc_plaquette_map, - plaquette_to_sites, + calc_plaquette_sizes, + gen_2d_bonds, gen_long_range_path, gen_long_range_swap_path, - swap_path_to_long_range_path, nearest_neighbors, + plaquette_to_sites, + swap_path_to_long_range_path, ) from .tensor_arbgeom_tebd import LocalHamGen, TEBDGen +from .tensor_core import Tensor class LocalHam2D(LocalHamGen): diff --git a/quimb/tensor/tensor_core.py b/quimb/tensor/tensor_core.py index 9148fe67..4648f649 100644 --- a/quimb/tensor/tensor_core.py +++ b/quimb/tensor/tensor_core.py @@ -1,16 +1,16 @@ """Core tensor network tools. """ -import os +import collections +import contextlib import copy -import uuid +import functools +import itertools import math +import operator +import os import string +import uuid import weakref -import operator -import functools -import itertools -import contextlib -import collections from numbers import Integral import numpy as np @@ -31,7 +31,8 @@ except ImportError: from ..core import common_type as get_common_dtype -from ..core import qarray, prod, realify_scalar, vdot, make_immutable +from ..core import make_immutable, prod, qarray, realify_scalar, vdot +from ..gen.rand import rand_matrix, rand_uni, randn, seed_rand from ..utils import ( check_opt, concat, @@ -44,9 +45,9 @@ unique, valmap, ) -from ..gen.rand import randn, seed_rand, rand_matrix, rand_uni from . import decomp from .array_ops import ( + PArray, asarray, find_antidiag_axes, find_columns, @@ -54,22 +55,13 @@ iscomplex, ndim, norm_fro, - PArray, -) -from .drawing import ( - auto_color_html, - draw_tn, - visualize_tensor, - visualize_tensors, ) - from .contraction import ( + array_contract, array_contract_expression, array_contract_path, array_contract_pathinfo, array_contract_tree, - array_contract, - contract_strategy, get_contract_backend, get_contract_strategy, get_symbol, @@ -77,6 +69,17 @@ inds_to_eq, inds_to_symbols, ) +from .drawing import ( + auto_color_html, + draw_tn, + visualize_tensor, + visualize_tensors, +) +from .fitting import ( + tensor_network_distance, + tensor_network_fit_als, + tensor_network_fit_autodiff, +) _inds_to_eq = deprecated(inds_to_eq, "_inds_to_eq", "inds_to_eq") get_symbol = deprecated( @@ -1287,381 +1290,6 @@ def maybe_unwrap( return maybe_realify_scalar(t.data) -def tensor_network_distance( - tnA, - tnB, - xAA=None, - xAB=None, - xBB=None, - method="auto", - normalized=False, - **contract_opts, -): - r"""Compute the Frobenius norm distance between two tensor networks: - - .. math:: - - D(A, B) - = | A - B |_{\mathrm{fro}} - = \mathrm{Tr} [(A - B)^{\dagger}(A - B)]^{1/2} - = ( \langle A | A \rangle - 2 \mathrm{Re} \langle A | B \rangle| - + \langle B | B \rangle ) ^{1/2} - - which should have matching outer indices. Note the default approach to - computing the norm is precision limited to about ``eps**0.5`` where ``eps`` - is the precision of the data type, e.g. ``1e-8`` for float64. This is due - to the subtraction in the above expression. - - Parameters - ---------- - tnA : TensorNetwork or Tensor - The first tensor network operator. - tnB : TensorNetwork or Tensor - The second tensor network operator. - xAA : None or scalar - The value of ``A.H @ A`` if you already know it (or it doesn't matter). - xAB : None or scalar - The value of ``A.H @ B`` if you already know it (or it doesn't matter). - xBB : None or scalar - The value of ``B.H @ B`` if you already know it (or it doesn't matter). - method : {'auto', 'overlap', 'dense'}, optional - How to compute the distance. If ``'overlap'``, the default, the - distance will be computed as the sum of overlaps, without explicitly - forming the dense operators. If ``'dense'``, the operators will be - directly formed and the norm computed, which can be quicker when the - exterior dimensions are small. If ``'auto'``, the dense method will - be used if the total operator (outer) size is ``<= 2**16``. - normalized : bool, optional - If ``True``, then normalize the distance by the norm of the two - operators, i.e. ``2 * D(A, B) / (|A| + |B|)``. The resulting distance - lies between 0 and 2 and is more useful for assessing convergence. - contract_opts - Supplied to :meth:`~quimb.tensor.tensor_core.TensorNetwork.contract`. - - Returns - ------- - D : float - """ - check_opt("method", method, ("auto", "dense", "overlap")) - - tnA = tnA.as_network() - tnB = tnB.as_network() - - oix = tnA.outer_inds() - if set(oix) != set(tnB.outer_inds()): - raise ValueError( - "Can only compute distance between tensor " - "networks with matching outer indices." - ) - - if method == "auto": - d = tnA.inds_size(oix) - if d <= 1 << 16: - method = "dense" - else: - method = "overlap" - - # directly from vectorizations of both - if method == "dense": - tnA = tnA.contract(..., output_inds=oix, preserve_tensor=True) - tnB = tnB.contract(..., output_inds=oix, preserve_tensor=True) - - # overlap method - if xAA is None: - xAA = (tnA | tnA.H).contract(..., **contract_opts) - if xAB is None: - xAB = (tnA | tnB.H).contract(..., **contract_opts) - if xBB is None: - xBB = (tnB | tnB.H).contract(..., **contract_opts) - - dAB = do("abs", xAA - 2 * do("real", xAB) + xBB) ** 0.5 - - if normalized: - dAB *= 2 / (do("abs", xAA)**0.5 + do("abs", xBB)**0.5) - - return dAB - - -def tensor_network_fit_autodiff( - tn, - tn_target, - steps=1000, - tol=1e-9, - autodiff_backend="autograd", - contract_optimize="auto-hq", - distance_method="auto", - inplace=False, - progbar=False, - **kwargs, -): - """Optimize the fit of ``tn`` with respect to ``tn_target`` using - automatic differentation. This minimizes the norm of the difference - between the two tensor networks, which must have matching outer indices, - using overlaps. - - Parameters - ---------- - tn : TensorNetwork - The tensor network to fit. - tn_target : TensorNetwork - The target tensor network to fit ``tn`` to. - steps : int, optional - The maximum number of autodiff steps. - tol : float, optional - The target norm distance. - autodiff_backend : str, optional - Which backend library to use to perform the gradient computation. - contract_optimize : str, optional - The contraction path optimized used to contract the overlaps. - distance_method : {'auto', 'dense', 'overlap'}, optional - Supplied to :func:`~quimb.tensor.tensor_core.tensor_network_distance`, - controls how the distance is computed. - inplace : bool, optional - Update ``tn`` in place. - progbar : bool, optional - Show a live progress bar of the fitting process. - kwargs - Passed to :class:`~quimb.tensor.tensor_core.optimize.TNOptimizer`. - - See Also - -------- - tensor_network_distance, tensor_network_fit_als - """ - from .optimize import TNOptimizer - - xBB = (tn_target | tn_target.H).contract( - ..., - output_inds=(), - optimize=contract_optimize, - ) - - tnopt = TNOptimizer( - tn=tn, - loss_fn=tensor_network_distance, - loss_constants={"tnB": tn_target, "xBB": xBB}, - loss_kwargs={"method": distance_method, "optimize": contract_optimize}, - autodiff_backend=autodiff_backend, - progbar=progbar, - **kwargs, - ) - - tn_fit = tnopt.optimize(steps, tol=tol) - - if not inplace: - return tn_fit - - for t1, t2 in zip(tn, tn_fit): - t1.modify(data=t2.data) - - return tn - - -def _tn_fit_als_core( - var_tags, - tnAA, - tnAB, - xBB, - tol, - contract_optimize, - steps, - enforce_pos, - pos_smudge, - solver="solve", - progbar=False, -): - # shared intermediates + greedy = good reuse of contractions - with contract_strategy(contract_optimize): - # prepare each of the contractions we are going to repeat - env_contractions = [] - for tg in var_tags: - # varying tensor and conjugate in norm - tk = tnAA["__KET__", tg] - tb = tnAA["__BRA__", tg] - - # get inds, and ensure any bonds come last, for linalg.solve - lix, bix, rix = group_inds(tb, tk) - tk.transpose_(*rix, *bix) - tb.transpose_(*lix, *bix) - - # form TNs with 'holes', i.e. environment tensors networks - A_tn = tnAA.select((tg,), "!all") - y_tn = tnAB.select((tg,), "!all") - - env_contractions.append((tk, tb, lix, bix, rix, A_tn, y_tn)) - - if tol != 0.0: - old_d = float("inf") - - if progbar: - import tqdm - - pbar = tqdm.trange(steps) - else: - pbar = range(steps) - - # the main iterative sweep on each tensor, locally optimizing - for _ in pbar: - for tk, tb, lix, bix, rix, A_tn, y_tn in env_contractions: - Ni = A_tn.to_dense(lix, rix) - Wi = y_tn.to_dense(rix, bix) - - if enforce_pos: - el, ev = do("linalg.eigh", Ni) - el = do("clip", el, el[-1] * pos_smudge, None) - Ni_p = ev * do("reshape", el, (1, -1)) @ dag(ev) - else: - Ni_p = Ni - - if solver == "solve": - x = do("linalg.solve", Ni_p, Wi) - elif solver == "lstsq": - x = do("linalg.lstsq", Ni_p, Wi, rcond=pos_smudge)[0] - - x_r = do("reshape", x, tk.shape) - # n.b. because we are using virtual TNs -> updates propagate - tk.modify(data=x_r) - tb.modify(data=do("conj", x_r)) - - # assess | A - B | for convergence or printing - if (tol != 0.0) or progbar: - xAA = do("trace", dag(x) @ (Ni @ x)) # - xAB = do("trace", do("real", dag(x) @ Wi)) # - d = do("abs", (xAA - 2 * xAB + xBB)) ** 0.5 - if abs(d - old_d) < tol: - break - old_d = d - - if progbar: - pbar.set_description(str(d)) - - -def tensor_network_fit_als( - tn, - tn_target, - tags=None, - steps=100, - tol=1e-9, - solver="solve", - enforce_pos=False, - pos_smudge=None, - tnAA=None, - tnAB=None, - xBB=None, - contract_optimize="greedy", - inplace=False, - progbar=False, -): - """Optimize the fit of ``tn`` with respect to ``tn_target`` using - alternating least squares (ALS). This minimizes the norm of the difference - between the two tensor networks, which must have matching outer indices, - using overlaps. - - Parameters - ---------- - tn : TensorNetwork - The tensor network to fit. - tn_target : TensorNetwork - The target tensor network to fit ``tn`` to. - tags : sequence of str, optional - If supplied, only optimize tensors matching any of given tags. - steps : int, optional - The maximum number of ALS steps. - tol : float, optional - The target norm distance. - solver : {'solve', 'lstsq', ...}, optional - The underlying driver function used to solve the local minimization, - e.g. ``numpy.linalg.solve`` for ``'solve'`` with ``numpy`` backend. - enforce_pos : bool, optional - Whether to enforce positivity of the locally formed environments, - which can be more stable. - pos_smudge : float, optional - If enforcing positivity, the level below which to clip eigenvalues - for make the local environment positive definite. - tnAA : TensorNetwork, optional - If you have already formed the overlap ``tn.H & tn``, maybe - approximately, you can supply it here. The unconjugated layer should - have tag ``'__KET__'`` and the conjugated layer ``'__BRA__'``. Each - tensor being optimized should have tag ``'__VAR{i}__'``. - tnAB : TensorNetwork, optional - If you have already formed the overlap ``tn_target.H & tn``, maybe - approximately, you can supply it here. Each tensor being optimized - should have tag ``'__VAR{i}__'``. - xBB : float, optional - If you have already know, have computed ``tn_target.H @ tn_target``, - or it doesn't matter, you can supply the value here. - contract_optimize : str, optional - The contraction path optimized used to contract the local environments. - Note ``'greedy'`` is the default in order to maximize shared work. - inplace : bool, optional - Update ``tn`` in place. - progbar : bool, optional - Show a live progress bar of the fitting process. - - Returns - ------- - TensorNetwork - - See Also - -------- - tensor_network_fit_autodiff, tensor_network_distance - """ - # mark the tensors we are going to optimize - tna = tn.copy() - tna.add_tag("__KET__") - - if tags is None: - to_tag = tna - else: - to_tag = tna.select_tensors(tags, "any") - - var_tags = [] - for i, t in enumerate(to_tag): - var_tag = f"__VAR{i}__" - t.add_tag(var_tag) - var_tags.append(var_tag) - - # form the norm of the varying TN (A) and its overlap with the target (B) - if tnAA is None: - tnAA = tna | tna.H.retag_({"__KET__": "__BRA__"}) - if tnAB is None: - tnAB = tna | tn_target.H - - if (tol != 0.0) and (xBB is None): - # - xBB = (tn_target | tn_target.H).contract( - ..., - optimize=contract_optimize, - output_inds=(), - ) - - if pos_smudge is None: - pos_smudge = max(tol, 1e-15) - - _tn_fit_als_core( - var_tags=var_tags, - tnAA=tnAA, - tnAB=tnAB, - xBB=xBB, - tol=tol, - contract_optimize=contract_optimize, - steps=steps, - enforce_pos=enforce_pos, - pos_smudge=pos_smudge, - solver=solver, - progbar=progbar, - ) - - if not inplace: - tn = tn.copy() - - for t1, t2 in zip(tn, tna): - # transpose so only thing changed in original TN is data - t2.transpose_like_(t1) - t1.modify(data=t2.data) - - return tn - - # --------------------------------------------------------------------------- # # Tensor Class # # --------------------------------------------------------------------------- # @@ -4028,7 +3656,8 @@ def __init__(self, ts=(), *, virtual=False, check_collisions=True): def combine(self, other, *, virtual=False, check_collisions=True): """Combine this tensor network with another, returning a new tensor - network. + network. This can be overriden by subclasses to check for a compatible + structured type. Parameters ---------- @@ -4837,8 +4466,8 @@ def geometry_hash(self, output_inds=None, strict_index_order=False): """ - import pickle import hashlib + import pickle inputs, output, size_dict = self.get_inputs_output_size_dict( output_inds=output_inds,