Skip to content

Commit

Permalink
add g.tensor to gpt_io
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Oct 30, 2024
1 parent 27c007a commit 073edfe
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 0 deletions.
7 changes: 7 additions & 0 deletions lib/gpt/core/io/gpt_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,8 @@ def create_index(self, ctx, objs):
f.write("}\n")
elif isinstance(objs, numpy.ndarray): # needs to be above list for proper precedence
f.write("array %d %d\n" % self.write_numpy(objs))
elif isinstance(objs, gpt.tensor):
f.write("tensor %s %d %d\n" % (objs.describe(), *self.write_numpy(objs.array)))
elif isinstance(objs, list):
f.write("[\n")
for i, x in enumerate(objs):
Expand Down Expand Up @@ -453,6 +455,11 @@ def read_index(self, p, ctx=""):
if not self.keep_context(ctx):
return None
return self.read_numpy(int(a[1]), int(a[2]))
elif cmd == "tensor":
a = p.get() # array start end
if not self.keep_context(ctx):
return None
return gpt.tensor(self.read_numpy(int(a[2]), int(a[3])), a[1])
elif cmd == "lattice":
a = p.get()
if not self.keep_context(ctx):
Expand Down
5 changes: 5 additions & 0 deletions lib/gpt/core/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class tensor(foundation_base):
def __init__(self, first, second=None):
if second is not None:
array, otype = first, second
if isinstance(otype, str):
otype = gpt.str_to_otype(otype)
else:
otype = first
array = np.zeros(otype.shape, dtype=np.complex128)
Expand All @@ -43,6 +45,9 @@ def __init__(self, first, second=None):
def __repr__(self):
return "tensor(%s,%s)" % (str(self.array), self.otype.__name__)

def describe(self):
return self.otype.__name__

def __getitem__(self, a):
return self.array.__getitem__(a)

Expand Down
5 changes: 5 additions & 0 deletions tests/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
"U": U, # write list of lattices
"sdomain": sdomain,
"S": S,
"tu": U[0][1, 1, 1, 1],
}

g.save(f"{work_dir}/out", to_save)
Expand Down Expand Up @@ -121,6 +122,10 @@ def check_all(res, tag):
eps2 += g.norm2(a - b)
assert eps2 < 1e-25

# tensor test
eps2 = g.norm2(U[0][1, 1, 1, 1] - res["tu"])
assert eps2 < 1e-25


check_all(g.load(f"{work_dir}/out"), "original mpi geometry")

Expand Down

0 comments on commit 073edfe

Please sign in to comment.