diff --git a/scratchpad/tn_api/tn.py b/scratchpad/tn_api/tn.py index 0acbefe1..078c3cb8 100644 --- a/scratchpad/tn_api/tn.py +++ b/scratchpad/tn_api/tn.py @@ -2,7 +2,6 @@ import math from dataclasses import dataclass from typing import TypeVar, Generic, Iterable -from qtree import np_framework class Array(np.ndarray): shape: tuple @@ -59,14 +58,23 @@ def __init__(self, *args, **kwargs): self._tensors = [] self._edges = tuple() self.shape = tuple() - self.buckets = [] self.data_dict = {} # slice not inplace def slice(self, slice_dict: dict) -> 'TensorNetwork': tn = self.copy() - sliced_buckets = np_framework.get_sliced_np_buckets(self.buckets, self.data_dict, slice_dict) - tn.buckets = sliced_buckets + sliced_tns = [] + for tensor in tn._tensors: + slice_bounds = [] + for idx in range(tensor.ndim): + try: + slice_bounds.append(slice_dict[idx]) + except KeyError: + slice_bounds.append(slice(None)) + + sliced_tns.append(tensor[tuple(slice_bounds)]) + + tn._tensors = sliced_tns return tn def copy(self): @@ -161,4 +169,6 @@ def __repr__(self): if __name__ == "__main__": tn = TensorNetwork.new_random_cpu(2, 3, 4) + slice_dict = {0: slice(0, 2), 1: slice(1, 3)} + sliced_tn = tn.slice(slice_dict) import pdb; pdb.set_trace() \ No newline at end of file