Skip to content

Commit

Permalink
type system 2
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Mar 25, 2024
1 parent 914efb3 commit 802b133
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
19 changes: 16 additions & 3 deletions lib/gpt/ad/reverse/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ def is_field(x):
raise Exception(f"Unknown object type {type(x)}")


def accumulate_compatible(a, b):
if a.data_alias is not None:
a = a.data_alias()
if b.data_alias is not None:
b = b.data_alias()
return a.__name__ == b.__name__


def accumulate_gradient(lhs, rhs_gradient, getter=None, setter=None):
lhs_gradient = lhs.gradient
if getter is not None:
Expand All @@ -49,15 +57,20 @@ def accumulate_gradient(lhs, rhs_gradient, getter=None, setter=None):

if isinstance(lhs_gradient, g.lattice) and isinstance(rhs_gradient, g.expr):
grid, rhs_otype, is_list, nlist = rhs_gradient.container()
assert not is_list # for now
assert not is_list # for now
lhs_otype = lhs_gradient.otype
if lhs_otype.__name__ != rhs_otype.__name__:
if rhs_otype.spintrace[2] is not None:
if lhs_otype.__name__ == rhs_otype.spintrace[2]().__name__:
rhs_spintrace_otype = rhs_otype.spintrace[2]()
if accumulate_compatible(lhs_otype, rhs_spintrace_otype):
rhs_gradient = g(g.spin_trace(rhs_gradient))
rhs_otype = rhs_gradient.otype
elif rhs_spintrace_otype.colortrace[2] is not None:
if accumulate_compatible(lhs_otype, rhs_spintrace_otype.colortrace[2]()):
rhs_gradient = g(g.trace(rhs_gradient))
rhs_otype = rhs_gradient.otype
if rhs_otype.colortrace[2] is not None:
if lhs_otype.__name__ == rhs_otype.colortrace[2]().__name__:
if accumulate_compatible(lhs_otype, rhs_otype.colortrace[2]()):
rhs_gradient = g(g.color_trace(rhs_gradient))
rhs_otype = rhs_gradient.otype

Expand Down
5 changes: 4 additions & 1 deletion lib/gpt/core/object_type/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,12 @@ def data_otype(self):
def is_self_dual(self):
return True

def identity():
def identity(self):
return 1.0

def infinitesimal_to_cartesian(self, a, da):
return da


###
# Matrices and vectors in color space
Expand Down

0 comments on commit 802b133

Please sign in to comment.