Skip to content

Commit

Permalink
first round of improvements to type system; thanks to Daniel Knuettel…
Browse files Browse the repository at this point in the history
… for the joint effort
  • Loading branch information
lehner committed Mar 23, 2024
1 parent 163ff0a commit 914efb3
Show file tree
Hide file tree
Showing 14 changed files with 120 additions and 114 deletions.
5 changes: 3 additions & 2 deletions lib/gpt/ad/reverse/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def is_field(x):
elif isinstance(x, g.tensor):
return False
elif isinstance(x, g.expr):
return x.lattice() is not None
return x.container()[0] is not None
elif g.util.is_num(x):
return False
elif isinstance(x, g.ad.forward.series):
Expand All @@ -48,7 +48,8 @@ def accumulate_gradient(lhs, rhs_gradient, getter=None, setter=None):
rhs_gradient = g(rhs_gradient)

if isinstance(lhs_gradient, g.lattice) and isinstance(rhs_gradient, g.expr):
rhs_otype = rhs_gradient.lattice().otype
grid, rhs_otype, is_list, nlist = rhs_gradient.container()
assert not is_list # for now

Check failure on line 52 in lib/gpt/ad/reverse/util.py

View workflow job for this annotation

GitHub Actions / lint

E261:at least two spaces before inline comment
lhs_otype = lhs_gradient.otype
if lhs_otype.__name__ != rhs_otype.__name__:
if rhs_otype.spintrace[2] is not None:
Expand Down
2 changes: 1 addition & 1 deletion lib/gpt/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def mat_backward(dst, src):

def coordinate_mask(field, mask):
assert isinstance(mask, numpy.ndarray)
assert field.otype.data_otype() == gpt.ot_singlet
assert isinstance(field.otype.data_otype(), gpt.ot_singlet)

x = gpt.coordinates(field)
field[x] = mask.astype(field.grid.precision.complex_dtype).reshape((len(mask), 1))
Expand Down
161 changes: 80 additions & 81 deletions lib/gpt/core/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,73 @@ class expr_unary:
BIT_COLORTRACE = 2


def get_otype_from_multiplication(t_otype, t_adj, f_otype, f_adj):
if f_adj and not t_adj and f_otype.itab is not None:
# inner
tab = f_otype.itab
rtab = {}
elif t_adj and not f_adj and f_otype.otab is not None:
# outer
tab = f_otype.otab
rtab = {}
else:
tab = f_otype.mtab
rtab = t_otype.rmtab

if t_otype.__name__ in tab:
return tab[t_otype.__name__][0]()
else:
if f_otype.__name__ not in rtab:
if f_otype.data_alias is not None:
return get_otype_from_multiplication(t_otype, t_adj, f_otype.data_alias(), f_adj)
elif t_otype.data_alias is not None:
return get_otype_from_multiplication(t_otype.data_alias(), t_adj, f_otype, f_adj)
else:
ajd_str_t = ".H" if t_adj else ""
ajd_str_f = ".H" if f_adj else ""
gpt.message(
f"Missing entry in multiplication table: {t_otype.__name__}{ajd_str_t} x {f_otype.__name__}{ajd_str_f}"
)
return rtab[f_otype.__name__][0]()


def get_otype_from_expression(e):
bare_otype = None
for coef, term in e.val:
if len(term) == 0:
t_otype = gpt.ot_singlet()
else:
t_otype = None
t_adj = False
for unary, factor in reversed(term):
f_otype = gpt.util.to_list(factor)[0].otype
f_adj = unary == factor_unary.ADJ and not f_otype.is_self_dual()
if t_otype is None:
t_otype = f_otype
t_adj = f_adj
else:
t_otype = get_otype_from_multiplication(t_otype, t_adj, f_otype, f_adj)

if bare_otype is None:
bare_otype = t_otype
else:
# all elements of a sum must have same data type
assert t_otype.data_otype().__name__ == bare_otype.data_otype().__name__

# apply unaries
if e.unary & expr_unary.BIT_SPINTRACE:
st = bare_otype.spintrace
assert st is not None
if st[2] is not None:
bare_otype = st[2]()
if e.unary & expr_unary.BIT_COLORTRACE:
ct = bare_otype.colortrace
assert ct is not None
if ct[2] is not None:
bare_otype = ct[2]()
return bare_otype


# expr:
# - each expression can have a unary operation such as trace
# - each expression has linear combination of terms
Expand Down Expand Up @@ -84,12 +151,19 @@ def is_num(self):
def get_num(self):
return self.val[0][0]

def lattice(self):
def container(self):
# returns triple grid, otype, n ; n can be None or an integer for a list of length n
for v in self.val:
for i in v[1]:
if gpt.util.is_list_instance(i[1], gpt.lattice):
return i[1]
return None
representative = i[1]
return_list = isinstance(representative, list)
representative = gpt.util.to_list(representative)
grid = representative[0].grid
n = len(representative)
otype = get_otype_from_expression(self)
return grid, otype, return_list, n
return None, None, None, None

def __mul__(self, l):
if isinstance(l, expr):
Expand All @@ -113,7 +187,7 @@ def __mul__(self, l):
if uf & gpt.factor_unary.BIT_CONJ != 0:
lhs = lhs.conj()
res = gpt.tensor(np.tensordot(lhs, l.array, axes=mt[1]), mt[0]())
if res.otype == gpt.ot_singlet:
if isinstance(res.otype, gpt.ot_singlet):
res = complex(res.array)
return res
assert 0
Expand Down Expand Up @@ -214,73 +288,6 @@ def apply_type_right_to_left(e, t):
assert 0


def get_otype_from_multiplication(t_otype, t_adj, f_otype, f_adj):
if f_adj and not t_adj and f_otype.itab is not None:
# inner
tab = f_otype.itab
rtab = {}
elif t_adj and not f_adj and f_otype.otab is not None:
# outer
tab = f_otype.otab
rtab = {}
else:
tab = f_otype.mtab
rtab = t_otype.rmtab

if t_otype.__name__ in tab:
return tab[t_otype.__name__][0]()
else:
if f_otype.__name__ not in rtab:
if f_otype.data_alias is not None:
return get_otype_from_multiplication(t_otype, t_adj, f_otype.data_alias(), f_adj)
elif t_otype.data_alias is not None:
return get_otype_from_multiplication(t_otype.data_alias(), t_adj, f_otype, f_adj)
else:
ajd_str_t = ".H" if t_adj else ""
ajd_str_f = ".H" if f_adj else ""
gpt.message(
f"Missing entry in multiplication table: {t_otype.__name__}{ajd_str_t} x {f_otype.__name__}{ajd_str_f}"
)
return rtab[f_otype.__name__][0]()


def get_otype_from_expression(e):
bare_otype = None
for coef, term in e.val:
if len(term) == 0:
t_otype = gpt.ot_singlet
else:
t_otype = None
t_adj = False
for unary, factor in reversed(term):
f_otype = gpt.util.to_list(factor)[0].otype
f_adj = unary == factor_unary.ADJ
if t_otype is None:
t_otype = f_otype
t_adj = f_adj
else:
t_otype = get_otype_from_multiplication(t_otype, t_adj, f_otype, f_adj)

if bare_otype is None:
bare_otype = t_otype
else:
# all elements of a sum must have same data type
assert t_otype.data_otype().__name__ == bare_otype.data_otype().__name__

# apply unaries
if e.unary & expr_unary.BIT_SPINTRACE:
st = bare_otype.spintrace
assert st is not None
if st[2] is not None:
bare_otype = st[2]()
if e.unary & expr_unary.BIT_COLORTRACE:
ct = bare_otype.colortrace
assert ct is not None
if ct[2] is not None:
bare_otype = ct[2]()
return bare_otype


def expr_eval(first, second=None, ac=False):
t = gpt.timer("eval", verbose_performance)

Expand Down Expand Up @@ -316,14 +323,10 @@ def expr_eval(first, second=None, ac=False):

t("prepare")
if dst is None:
lat = e.lattice()
if lat is None:
grid, otype, return_list, nlat = e.container()
if grid is None:
# cannot evaluate to a lattice object, leave expression unevaluated
return first
return_list = isinstance(lat, list)
lat = gpt.util.to_list(lat)
grid = lat[0].grid
nlat = len(lat)

# verbose output
if verbose:
Expand All @@ -339,12 +342,8 @@ def expr_eval(first, second=None, ac=False):
ret = dst
else:
assert ac is False
t("get otype")
# now find return type
otype = get_otype_from_expression(e)

ret = []

for idx in range(nlat):
t("cgpt.eval")
res = cgpt.eval(None, e.val, e.unary, False, idx)
Expand Down
4 changes: 2 additions & 2 deletions lib/gpt/core/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,8 @@ def __itruediv__(self, expr):
return self

def __lt__(self, other):
assert self.otype.data_otype() == gpt.ot_singlet
assert other.otype.data_otype() == gpt.ot_singlet
assert isinstance(self.otype.data_otype(), gpt.ot_singlet)
assert isinstance(other.otype.data_otype(), gpt.ot_singlet)
res = gpt.lattice(self)
params = {"operator": "<"}
cgpt.binary(res.v_obj[0], self.v_obj[0], other.v_obj[0], params)
Expand Down
2 changes: 1 addition & 1 deletion lib/gpt/core/object_type/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def gpt_object(first, ot):
###
# Container objects without (lie) group structure
def singlet(grid):
return gpt_object(grid, ot_singlet)
return gpt_object(grid, ot_singlet())


def matrix_color(grid, ndim):
Expand Down
3 changes: 3 additions & 0 deletions lib/gpt/core/object_type/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,6 @@ def data_otype(self):
if self.data_alias is not None:
return self.data_alias()
return self

def is_self_dual(self):
return False
10 changes: 5 additions & 5 deletions lib/gpt/core/object_type/complex_additive_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ class ot_complex_additive_group(ot_singlet):

def __init__(self):
self.__name__ = "ot_complex_additive_group"
self.data_alias = lambda: ot_singlet
self.data_alias = lambda: ot_singlet()
self.rmtab = {
"ot_singlet": (lambda: ot_singlet, None),
"ot_singlet": (lambda: ot_singlet(), None),
}
self.mtab = {
self.__name__: (lambda: self, None),
"ot_singlet": (lambda: ot_singlet, None),
"ot_singlet": (lambda: ot_singlet(), None),
}

# this is always multiplicative identity, not neutral element of group
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(self, n):
}
self.otab = {self.__name__: (lambda: ot_matrix_complex_additive_group(n), [])}
self.itab = {
self.__name__: (lambda: ot_singlet, (0, 0)),
self.__name__: (lambda: ot_singlet(), (0, 0)),
}
self.cache = {}

Expand Down Expand Up @@ -119,7 +119,7 @@ def coordinates(self, l, c=None):
assert l.otype.__name__ == self.__name__
if c is None:
r = [None] * self.shape[0] * 2
a = gpt.separate_indices(l, (0, lambda: ot_singlet), self.cache)
a = gpt.separate_indices(l, (0, lambda: ot_singlet()), self.cache)
for i in a:
r[i[0]] = gpt.component.real(a[i])
r[i[0] + self.shape[0]] = gpt.component.imag(a[i])
Expand Down
23 changes: 13 additions & 10 deletions lib/gpt/core/object_type/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,14 @@ class ot_singlet(ot_base):
colortrace = (None, None, None)
v_otype = ["ot_singlet"]
mtab = {
"ot_singlet": (lambda: ot_singlet, None),
"ot_singlet": (lambda: ot_singlet(), None),
}

def data_otype(self=None):
return ot_singlet
def data_otype(self):
return ot_singlet()

def is_self_dual(self):
return True

def identity():
return 1.0
Expand All @@ -84,7 +87,7 @@ def __init__(self, ndim):
self.shape = (ndim, ndim)
self.transposed = (1, 0)
self.spintrace = (None, None, None) # do nothing
self.colortrace = (0, 1, lambda: ot_singlet)
self.colortrace = (0, 1, lambda: ot_singlet())
self.v_otype = ["ot_mcolor%d" % ndim] # cgpt data types
self.mtab = {
self.__name__: (lambda: self, (1, 0)),
Expand Down Expand Up @@ -115,7 +118,7 @@ def __init__(self, ndim):
}
self.otab = {self.__name__: (lambda: ot_matrix_color(ndim), [])}
self.itab = {
self.__name__: (lambda: ot_singlet, (0, 0)),
self.__name__: (lambda: ot_singlet(), (0, 0)),
}

def compose(self, a, b):
Expand All @@ -133,7 +136,7 @@ def __init__(self, ndim):
self.nfloats = 2 * ndim * ndim
self.shape = (ndim, ndim)
self.transposed = (1, 0)
self.spintrace = (0, 1, lambda: ot_singlet)
self.spintrace = (0, 1, lambda: ot_singlet())
self.colortrace = (None, None, None) # do nothing
self.v_otype = ["ot_mspin%d" % ndim]
self.mtab = {
Expand Down Expand Up @@ -197,7 +200,7 @@ def __init__(self, ndim):
"ot_singlet": (lambda: self, None),
}
self.otab = {self.__name__: (lambda: ot_matrix_spin(ndim), [])}
self.itab = {self.__name__: (lambda: ot_singlet, (0, 0))}
self.itab = {self.__name__: (lambda: ot_singlet(), (0, 0))}

def compose(self, a, b):
return a + b
Expand Down Expand Up @@ -269,7 +272,7 @@ def __init__(self, spin_ndim, color_ndim):
),
}
self.itab = {
self.__name__: (lambda: ot_singlet, ([0, 1], [0, 1])),
self.__name__: (lambda: ot_singlet(), ([0, 1], [0, 1])),
}
self.mtab = {
"ot_singlet": (lambda: self, None),
Expand Down Expand Up @@ -346,7 +349,7 @@ def __init__(self, n):
"ot_singlet": (lambda: self, None),
}
self.itab = {
self.__name__: (lambda: ot_singlet, (0, 0)),
self.__name__: (lambda: ot_singlet(), (0, 0)),
}


Expand All @@ -367,7 +370,7 @@ def __init__(self, n):
self.shape = (n, n)
self.transposed = (1, 0)
self.spintrace = (None, None, None)
self.colortrace = (0, 1, lambda: ot_singlet)
self.colortrace = (0, 1, lambda: ot_singlet())
self.vector_type = ot_vector_singlet(n)
self.mtab = {
self.__name__: (lambda: self, (1, 0)),
Expand Down
Loading

0 comments on commit 914efb3

Please sign in to comment.