From 914efb367a7b627dcaff7d3e9fb78ecb3d011cca Mon Sep 17 00:00:00 2001 From: Christoph Lehner Date: Sat, 23 Mar 2024 13:23:36 +0100 Subject: [PATCH] first round of improvements to type system; thanks to Daniel Knuettel for the joint effort --- lib/gpt/ad/reverse/util.py | 5 +- lib/gpt/core/coordinates.py | 2 +- lib/gpt/core/expr.py | 161 +++++++++--------- lib/gpt/core/lattice.py | 4 +- lib/gpt/core/object_type/__init__.py | 2 +- lib/gpt/core/object_type/base.py | 3 + .../object_type/complex_additive_group.py | 10 +- lib/gpt/core/object_type/container.py | 23 +-- .../core/object_type/real_additive_group.py | 10 +- lib/gpt/core/object_type/u_1.py | 6 +- lib/gpt/core/tensor.py | 2 +- lib/gpt/core/util.py | 2 +- lib/gpt/create/sparse_grid.py | 2 +- lib/gpt/ml/layer/nearest_neighbor.py | 2 +- 14 files changed, 120 insertions(+), 114 deletions(-) diff --git a/lib/gpt/ad/reverse/util.py b/lib/gpt/ad/reverse/util.py index 0993a2597..48839dd7b 100644 --- a/lib/gpt/ad/reverse/util.py +++ b/lib/gpt/ad/reverse/util.py @@ -25,7 +25,7 @@ def is_field(x): elif isinstance(x, g.tensor): return False elif isinstance(x, g.expr): - return x.lattice() is not None + return x.container()[0] is not None elif g.util.is_num(x): return False elif isinstance(x, g.ad.forward.series): @@ -48,7 +48,8 @@ def accumulate_gradient(lhs, rhs_gradient, getter=None, setter=None): rhs_gradient = g(rhs_gradient) if isinstance(lhs_gradient, g.lattice) and isinstance(rhs_gradient, g.expr): - rhs_otype = rhs_gradient.lattice().otype + grid, rhs_otype, is_list, nlist = rhs_gradient.container() + assert not is_list # for now lhs_otype = lhs_gradient.otype if lhs_otype.__name__ != rhs_otype.__name__: if rhs_otype.spintrace[2] is not None: diff --git a/lib/gpt/core/coordinates.py b/lib/gpt/core/coordinates.py index abb03a758..d402ef906 100644 --- a/lib/gpt/core/coordinates.py +++ b/lib/gpt/core/coordinates.py @@ -131,7 +131,7 @@ def mat_backward(dst, src): def coordinate_mask(field, mask): assert isinstance(mask, numpy.ndarray) - assert field.otype.data_otype() == gpt.ot_singlet + assert isinstance(field.otype.data_otype(), gpt.ot_singlet) x = gpt.coordinates(field) field[x] = mask.astype(field.grid.precision.complex_dtype).reshape((len(mask), 1)) diff --git a/lib/gpt/core/expr.py b/lib/gpt/core/expr.py index bfd7ba994..9e95f38b0 100644 --- a/lib/gpt/core/expr.py +++ b/lib/gpt/core/expr.py @@ -37,6 +37,73 @@ class expr_unary: BIT_COLORTRACE = 2 +def get_otype_from_multiplication(t_otype, t_adj, f_otype, f_adj): + if f_adj and not t_adj and f_otype.itab is not None: + # inner + tab = f_otype.itab + rtab = {} + elif t_adj and not f_adj and f_otype.otab is not None: + # outer + tab = f_otype.otab + rtab = {} + else: + tab = f_otype.mtab + rtab = t_otype.rmtab + + if t_otype.__name__ in tab: + return tab[t_otype.__name__][0]() + else: + if f_otype.__name__ not in rtab: + if f_otype.data_alias is not None: + return get_otype_from_multiplication(t_otype, t_adj, f_otype.data_alias(), f_adj) + elif t_otype.data_alias is not None: + return get_otype_from_multiplication(t_otype.data_alias(), t_adj, f_otype, f_adj) + else: + ajd_str_t = ".H" if t_adj else "" + ajd_str_f = ".H" if f_adj else "" + gpt.message( + f"Missing entry in multiplication table: {t_otype.__name__}{ajd_str_t} x {f_otype.__name__}{ajd_str_f}" + ) + return rtab[f_otype.__name__][0]() + + +def get_otype_from_expression(e): + bare_otype = None + for coef, term in e.val: + if len(term) == 0: + t_otype = gpt.ot_singlet() + else: + t_otype = None + t_adj = False + for unary, factor in reversed(term): + f_otype = gpt.util.to_list(factor)[0].otype + f_adj = unary == factor_unary.ADJ and not f_otype.is_self_dual() + if t_otype is None: + t_otype = f_otype + t_adj = f_adj + else: + t_otype = get_otype_from_multiplication(t_otype, t_adj, f_otype, f_adj) + + if bare_otype is None: + bare_otype = t_otype + else: + # all elements of a sum must have same data type + assert t_otype.data_otype().__name__ == bare_otype.data_otype().__name__ + + # apply unaries + if e.unary & expr_unary.BIT_SPINTRACE: + st = bare_otype.spintrace + assert st is not None + if st[2] is not None: + bare_otype = st[2]() + if e.unary & expr_unary.BIT_COLORTRACE: + ct = bare_otype.colortrace + assert ct is not None + if ct[2] is not None: + bare_otype = ct[2]() + return bare_otype + + # expr: # - each expression can have a unary operation such as trace # - each expression has linear combination of terms @@ -84,12 +151,19 @@ def is_num(self): def get_num(self): return self.val[0][0] - def lattice(self): + def container(self): + # returns triple grid, otype, n ; n can be None or an integer for a list of length n for v in self.val: for i in v[1]: if gpt.util.is_list_instance(i[1], gpt.lattice): - return i[1] - return None + representative = i[1] + return_list = isinstance(representative, list) + representative = gpt.util.to_list(representative) + grid = representative[0].grid + n = len(representative) + otype = get_otype_from_expression(self) + return grid, otype, return_list, n + return None, None, None, None def __mul__(self, l): if isinstance(l, expr): @@ -113,7 +187,7 @@ def __mul__(self, l): if uf & gpt.factor_unary.BIT_CONJ != 0: lhs = lhs.conj() res = gpt.tensor(np.tensordot(lhs, l.array, axes=mt[1]), mt[0]()) - if res.otype == gpt.ot_singlet: + if isinstance(res.otype, gpt.ot_singlet): res = complex(res.array) return res assert 0 @@ -214,73 +288,6 @@ def apply_type_right_to_left(e, t): assert 0 -def get_otype_from_multiplication(t_otype, t_adj, f_otype, f_adj): - if f_adj and not t_adj and f_otype.itab is not None: - # inner - tab = f_otype.itab - rtab = {} - elif t_adj and not f_adj and f_otype.otab is not None: - # outer - tab = f_otype.otab - rtab = {} - else: - tab = f_otype.mtab - rtab = t_otype.rmtab - - if t_otype.__name__ in tab: - return tab[t_otype.__name__][0]() - else: - if f_otype.__name__ not in rtab: - if f_otype.data_alias is not None: - return get_otype_from_multiplication(t_otype, t_adj, f_otype.data_alias(), f_adj) - elif t_otype.data_alias is not None: - return get_otype_from_multiplication(t_otype.data_alias(), t_adj, f_otype, f_adj) - else: - ajd_str_t = ".H" if t_adj else "" - ajd_str_f = ".H" if f_adj else "" - gpt.message( - f"Missing entry in multiplication table: {t_otype.__name__}{ajd_str_t} x {f_otype.__name__}{ajd_str_f}" - ) - return rtab[f_otype.__name__][0]() - - -def get_otype_from_expression(e): - bare_otype = None - for coef, term in e.val: - if len(term) == 0: - t_otype = gpt.ot_singlet - else: - t_otype = None - t_adj = False - for unary, factor in reversed(term): - f_otype = gpt.util.to_list(factor)[0].otype - f_adj = unary == factor_unary.ADJ - if t_otype is None: - t_otype = f_otype - t_adj = f_adj - else: - t_otype = get_otype_from_multiplication(t_otype, t_adj, f_otype, f_adj) - - if bare_otype is None: - bare_otype = t_otype - else: - # all elements of a sum must have same data type - assert t_otype.data_otype().__name__ == bare_otype.data_otype().__name__ - - # apply unaries - if e.unary & expr_unary.BIT_SPINTRACE: - st = bare_otype.spintrace - assert st is not None - if st[2] is not None: - bare_otype = st[2]() - if e.unary & expr_unary.BIT_COLORTRACE: - ct = bare_otype.colortrace - assert ct is not None - if ct[2] is not None: - bare_otype = ct[2]() - return bare_otype - - def expr_eval(first, second=None, ac=False): t = gpt.timer("eval", verbose_performance) @@ -316,14 +323,10 @@ def expr_eval(first, second=None, ac=False): t("prepare") if dst is None: - lat = e.lattice() - if lat is None: + grid, otype, return_list, nlat = e.container() + if grid is None: # cannot evaluate to a lattice object, leave expression unevaluated return first - return_list = isinstance(lat, list) - lat = gpt.util.to_list(lat) - grid = lat[0].grid - nlat = len(lat) # verbose output if verbose: @@ -339,12 +342,8 @@ def expr_eval(first, second=None, ac=False): ret = dst else: assert ac is False - t("get otype") - # now find return type - otype = get_otype_from_expression(e) ret = [] - for idx in range(nlat): t("cgpt.eval") res = cgpt.eval(None, e.val, e.unary, False, idx) diff --git a/lib/gpt/core/lattice.py b/lib/gpt/core/lattice.py index 10eeae6c6..ea3788472 100644 --- a/lib/gpt/core/lattice.py +++ b/lib/gpt/core/lattice.py @@ -288,8 +288,8 @@ def __itruediv__(self, expr): return self def __lt__(self, other): - assert self.otype.data_otype() == gpt.ot_singlet - assert other.otype.data_otype() == gpt.ot_singlet + assert isinstance(self.otype.data_otype(), gpt.ot_singlet) + assert isinstance(other.otype.data_otype(), gpt.ot_singlet) res = gpt.lattice(self) params = {"operator": "<"} cgpt.binary(res.v_obj[0], self.v_obj[0], other.v_obj[0], params) diff --git a/lib/gpt/core/object_type/__init__.py b/lib/gpt/core/object_type/__init__.py index f7b4c3a28..e90f5045e 100644 --- a/lib/gpt/core/object_type/__init__.py +++ b/lib/gpt/core/object_type/__init__.py @@ -38,7 +38,7 @@ def gpt_object(first, ot): ### # Container objects without (lie) group structure def singlet(grid): - return gpt_object(grid, ot_singlet) + return gpt_object(grid, ot_singlet()) def matrix_color(grid, ndim): diff --git a/lib/gpt/core/object_type/base.py b/lib/gpt/core/object_type/base.py index 9f60fe8da..45d1b2549 100644 --- a/lib/gpt/core/object_type/base.py +++ b/lib/gpt/core/object_type/base.py @@ -42,3 +42,6 @@ def data_otype(self): if self.data_alias is not None: return self.data_alias() return self + + def is_self_dual(self): + return False diff --git a/lib/gpt/core/object_type/complex_additive_group.py b/lib/gpt/core/object_type/complex_additive_group.py index d1ffaac20..f7889e91a 100644 --- a/lib/gpt/core/object_type/complex_additive_group.py +++ b/lib/gpt/core/object_type/complex_additive_group.py @@ -30,13 +30,13 @@ class ot_complex_additive_group(ot_singlet): def __init__(self): self.__name__ = "ot_complex_additive_group" - self.data_alias = lambda: ot_singlet + self.data_alias = lambda: ot_singlet() self.rmtab = { - "ot_singlet": (lambda: ot_singlet, None), + "ot_singlet": (lambda: ot_singlet(), None), } self.mtab = { self.__name__: (lambda: self, None), - "ot_singlet": (lambda: ot_singlet, None), + "ot_singlet": (lambda: ot_singlet(), None), } # this is always multiplicative identity, not neutral element of group @@ -87,7 +87,7 @@ def __init__(self, n): } self.otab = {self.__name__: (lambda: ot_matrix_complex_additive_group(n), [])} self.itab = { - self.__name__: (lambda: ot_singlet, (0, 0)), + self.__name__: (lambda: ot_singlet(), (0, 0)), } self.cache = {} @@ -119,7 +119,7 @@ def coordinates(self, l, c=None): assert l.otype.__name__ == self.__name__ if c is None: r = [None] * self.shape[0] * 2 - a = gpt.separate_indices(l, (0, lambda: ot_singlet), self.cache) + a = gpt.separate_indices(l, (0, lambda: ot_singlet()), self.cache) for i in a: r[i[0]] = gpt.component.real(a[i]) r[i[0] + self.shape[0]] = gpt.component.imag(a[i]) diff --git a/lib/gpt/core/object_type/container.py b/lib/gpt/core/object_type/container.py index 6dfca4937..ff110d0fa 100644 --- a/lib/gpt/core/object_type/container.py +++ b/lib/gpt/core/object_type/container.py @@ -65,11 +65,14 @@ class ot_singlet(ot_base): colortrace = (None, None, None) v_otype = ["ot_singlet"] mtab = { - "ot_singlet": (lambda: ot_singlet, None), + "ot_singlet": (lambda: ot_singlet(), None), } - def data_otype(self=None): - return ot_singlet + def data_otype(self): + return ot_singlet() + + def is_self_dual(self): + return True def identity(): return 1.0 @@ -84,7 +87,7 @@ def __init__(self, ndim): self.shape = (ndim, ndim) self.transposed = (1, 0) self.spintrace = (None, None, None) # do nothing - self.colortrace = (0, 1, lambda: ot_singlet) + self.colortrace = (0, 1, lambda: ot_singlet()) self.v_otype = ["ot_mcolor%d" % ndim] # cgpt data types self.mtab = { self.__name__: (lambda: self, (1, 0)), @@ -115,7 +118,7 @@ def __init__(self, ndim): } self.otab = {self.__name__: (lambda: ot_matrix_color(ndim), [])} self.itab = { - self.__name__: (lambda: ot_singlet, (0, 0)), + self.__name__: (lambda: ot_singlet(), (0, 0)), } def compose(self, a, b): @@ -133,7 +136,7 @@ def __init__(self, ndim): self.nfloats = 2 * ndim * ndim self.shape = (ndim, ndim) self.transposed = (1, 0) - self.spintrace = (0, 1, lambda: ot_singlet) + self.spintrace = (0, 1, lambda: ot_singlet()) self.colortrace = (None, None, None) # do nothing self.v_otype = ["ot_mspin%d" % ndim] self.mtab = { @@ -197,7 +200,7 @@ def __init__(self, ndim): "ot_singlet": (lambda: self, None), } self.otab = {self.__name__: (lambda: ot_matrix_spin(ndim), [])} - self.itab = {self.__name__: (lambda: ot_singlet, (0, 0))} + self.itab = {self.__name__: (lambda: ot_singlet(), (0, 0))} def compose(self, a, b): return a + b @@ -269,7 +272,7 @@ def __init__(self, spin_ndim, color_ndim): ), } self.itab = { - self.__name__: (lambda: ot_singlet, ([0, 1], [0, 1])), + self.__name__: (lambda: ot_singlet(), ([0, 1], [0, 1])), } self.mtab = { "ot_singlet": (lambda: self, None), @@ -346,7 +349,7 @@ def __init__(self, n): "ot_singlet": (lambda: self, None), } self.itab = { - self.__name__: (lambda: ot_singlet, (0, 0)), + self.__name__: (lambda: ot_singlet(), (0, 0)), } @@ -367,7 +370,7 @@ def __init__(self, n): self.shape = (n, n) self.transposed = (1, 0) self.spintrace = (None, None, None) - self.colortrace = (0, 1, lambda: ot_singlet) + self.colortrace = (0, 1, lambda: ot_singlet()) self.vector_type = ot_vector_singlet(n) self.mtab = { self.__name__: (lambda: self, (1, 0)), diff --git a/lib/gpt/core/object_type/real_additive_group.py b/lib/gpt/core/object_type/real_additive_group.py index 7cb5dcb20..e5b29e5a7 100644 --- a/lib/gpt/core/object_type/real_additive_group.py +++ b/lib/gpt/core/object_type/real_additive_group.py @@ -30,13 +30,13 @@ class ot_real_additive_group(ot_singlet): def __init__(self): self.__name__ = "ot_real_additive_group" - self.data_alias = lambda: ot_singlet + self.data_alias = lambda: ot_singlet() self.rmtab = { - "ot_singlet": (lambda: ot_singlet, None), + "ot_singlet": (lambda: ot_singlet(), None), } self.mtab = { self.__name__: (lambda: self, None), - "ot_singlet": (lambda: ot_singlet, None), + "ot_singlet": (lambda: ot_singlet(), None), } # this is always multiplicative identity, not neutral element of group @@ -87,7 +87,7 @@ def __init__(self, n): } self.otab = {self.__name__: (lambda: ot_matrix_real_additive_group(n), [])} self.itab = { - self.__name__: (lambda: ot_singlet, (0, 0)), + self.__name__: (lambda: ot_singlet(), (0, 0)), } self.cache = {} @@ -114,7 +114,7 @@ def coordinates(self, l, c=None): assert l.otype.__name__ == self.__name__ if c is None: r = [None] * self.shape[0] - a = gpt.separate_indices(l, (0, lambda: ot_singlet), self.cache) + a = gpt.separate_indices(l, (0, lambda: ot_singlet()), self.cache) for i in a: r[i[0]] = a[i] return r diff --git a/lib/gpt/core/object_type/u_1.py b/lib/gpt/core/object_type/u_1.py index fb0724f96..936b845c3 100644 --- a/lib/gpt/core/object_type/u_1.py +++ b/lib/gpt/core/object_type/u_1.py @@ -34,13 +34,13 @@ def identity(self): def __init__(self, name): self.__name__ = name - self.data_alias = lambda: ot_singlet + self.data_alias = lambda: ot_singlet() self.rmtab = { - "ot_singlet": (lambda: ot_singlet, None), + "ot_singlet": (lambda: ot_singlet(), None), } self.mtab = { self.__name__: (lambda: self, None), - "ot_singlet": (lambda: ot_singlet, None), + "ot_singlet": (lambda: ot_singlet(), None), } diff --git a/lib/gpt/core/tensor.py b/lib/gpt/core/tensor.py index 9ef9a5ee8..e52a65356 100644 --- a/lib/gpt/core/tensor.py +++ b/lib/gpt/core/tensor.py @@ -92,7 +92,7 @@ def trace(self, t): if ct[0] is not None: res = tensor(np.trace(res.array, offset=0, axis1=ct[0], axis2=ct[1]), ct[2]()) - if res.otype == gpt.ot_singlet: + if isinstance(res.otype, gpt.ot_singlet): res = complex(res.array) return res diff --git a/lib/gpt/core/util.py b/lib/gpt/core/util.py index d4b03c41e..2278f8154 100644 --- a/lib/gpt/core/util.py +++ b/lib/gpt/core/util.py @@ -52,7 +52,7 @@ def to_num(x): # tensor def value_to_tensor(val, otype): - if otype.data_otype() == gpt.ot_singlet: + if isinstance(otype.data_otype(), gpt.ot_singlet): # this is not ideal, can we do a subclass of complex that preserves otype info? return complex(val) return gpt.tensor(val, otype) diff --git a/lib/gpt/create/sparse_grid.py b/lib/gpt/create/sparse_grid.py index d24d7d281..99e28328f 100644 --- a/lib/gpt/create/sparse_grid.py +++ b/lib/gpt/create/sparse_grid.py @@ -26,7 +26,7 @@ def coordinates(src, position, spacing): def zn(src, position, spacing, rng, n): - singlet = gpt.lattice(src.grid, gpt.ot_singlet) + singlet = gpt.lattice(src.grid, gpt.ot_singlet()) singlet.checkerboard(src.checkerboard()) pos = coordinates(src, position, spacing) singlet_full = gpt.lattice(singlet) diff --git a/lib/gpt/ml/layer/nearest_neighbor.py b/lib/gpt/ml/layer/nearest_neighbor.py index e4c67ae45..24145490c 100644 --- a/lib/gpt/ml/layer/nearest_neighbor.py +++ b/lib/gpt/ml/layer/nearest_neighbor.py @@ -22,7 +22,7 @@ class nearest_neighbor(cshift): - def __init__(self, grid, ot_input=g.ot_singlet, ot_weights=g.ot_singlet, activation=sigmoid): + def __init__(self, grid, ot_input=g.ot_singlet(), ot_weights=g.ot_singlet(), activation=sigmoid): nd = grid.nd super().__init__( grid,