Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
dallonasnes committed Oct 22, 2023
1 parent 0d74462 commit 55dda4f
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions scratchpad/tn_api/tn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import math
from dataclasses import dataclass
from typing import TypeVar, Generic, Iterable
from qtree import np_framework

class Array(np.ndarray):
shape: tuple
Expand Down Expand Up @@ -59,14 +58,23 @@ def __init__(self, *args, **kwargs):
self._tensors = []
self._edges = tuple()
self.shape = tuple()
self.buckets = []
self.data_dict = {}

# slice not inplace
def slice(self, slice_dict: dict) -> 'TensorNetwork':
tn = self.copy()
sliced_buckets = np_framework.get_sliced_np_buckets(self.buckets, self.data_dict, slice_dict)
tn.buckets = sliced_buckets
sliced_tns = []
for tensor in tn._tensors:
slice_bounds = []
for idx in range(tensor.ndim):
try:
slice_bounds.append(slice_dict[idx])
except KeyError:
slice_bounds.append(slice(None))

sliced_tns.append(tensor[tuple(slice_bounds)])

tn._tensors = sliced_tns
return tn

def copy(self):
Expand Down Expand Up @@ -161,4 +169,6 @@ def __repr__(self):

if __name__ == "__main__":
tn = TensorNetwork.new_random_cpu(2, 3, 4)
slice_dict = {0: slice(0, 2), 1: slice(1, 3)}
sliced_tn = tn.slice(slice_dict)
import pdb; pdb.set_trace()

0 comments on commit 55dda4f

Please sign in to comment.