diff --git a/lib/gpt/core/io/gpt_io.py b/lib/gpt/core/io/gpt_io.py index 6debd3ad..e04d0b54 100644 --- a/lib/gpt/core/io/gpt_io.py +++ b/lib/gpt/core/io/gpt_io.py @@ -362,6 +362,8 @@ def create_index(self, ctx, objs): f.write("}\n") elif isinstance(objs, numpy.ndarray): # needs to be above list for proper precedence f.write("array %d %d\n" % self.write_numpy(objs)) + elif isinstance(objs, gpt.tensor): + f.write("tensor %s %d %d\n" % (objs.describe(), *self.write_numpy(objs.array))) elif isinstance(objs, list): f.write("[\n") for i, x in enumerate(objs): @@ -453,6 +455,11 @@ def read_index(self, p, ctx=""): if not self.keep_context(ctx): return None return self.read_numpy(int(a[1]), int(a[2])) + elif cmd == "tensor": + a = p.get() # array start end + if not self.keep_context(ctx): + return None + return gpt.tensor(self.read_numpy(int(a[2]), int(a[3])), a[1]) elif cmd == "lattice": a = p.get() if not self.keep_context(ctx): diff --git a/lib/gpt/core/tensor.py b/lib/gpt/core/tensor.py index e52a6535..f7e2975b 100644 --- a/lib/gpt/core/tensor.py +++ b/lib/gpt/core/tensor.py @@ -28,6 +28,8 @@ class tensor(foundation_base): def __init__(self, first, second=None): if second is not None: array, otype = first, second + if isinstance(otype, str): + otype = gpt.str_to_otype(otype) else: otype = first array = np.zeros(otype.shape, dtype=np.complex128) @@ -43,6 +45,9 @@ def __init__(self, first, second=None): def __repr__(self): return "tensor(%s,%s)" % (str(self.array), self.otype.__name__) + def describe(self): + return self.otype.__name__ + def __getitem__(self, a): return self.array.__getitem__(a) diff --git a/tests/io/io.py b/tests/io/io.py index 95a08f73..63dac230 100755 --- a/tests/io/io.py +++ b/tests/io/io.py @@ -59,6 +59,7 @@ "U": U, # write list of lattices "sdomain": sdomain, "S": S, + "tu": U[0][1, 1, 1, 1], } g.save(f"{work_dir}/out", to_save) @@ -121,6 +122,10 @@ def check_all(res, tag): eps2 += g.norm2(a - b) assert eps2 < 1e-25 + # tensor test + eps2 = g.norm2(U[0][1, 1, 1, 1] - res["tu"]) + assert eps2 < 1e-25 + check_all(g.load(f"{work_dir}/out"), "original mpi geometry")