From 802b133016b6808abb3f90d60c06e47b8bb515a6 Mon Sep 17 00:00:00 2001 From: Christoph Lehner Date: Mon, 25 Mar 2024 11:20:23 +0100 Subject: [PATCH] type system 2 --- lib/gpt/ad/reverse/util.py | 19 ++++++++++++++++--- lib/gpt/core/object_type/container.py | 5 ++++- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/lib/gpt/ad/reverse/util.py b/lib/gpt/ad/reverse/util.py index 48839dd7..98cbc12e 100644 --- a/lib/gpt/ad/reverse/util.py +++ b/lib/gpt/ad/reverse/util.py @@ -36,6 +36,14 @@ def is_field(x): raise Exception(f"Unknown object type {type(x)}") +def accumulate_compatible(a, b): + if a.data_alias is not None: + a = a.data_alias() + if b.data_alias is not None: + b = b.data_alias() + return a.__name__ == b.__name__ + + def accumulate_gradient(lhs, rhs_gradient, getter=None, setter=None): lhs_gradient = lhs.gradient if getter is not None: @@ -49,15 +57,20 @@ def accumulate_gradient(lhs, rhs_gradient, getter=None, setter=None): if isinstance(lhs_gradient, g.lattice) and isinstance(rhs_gradient, g.expr): grid, rhs_otype, is_list, nlist = rhs_gradient.container() - assert not is_list # for now + assert not is_list # for now lhs_otype = lhs_gradient.otype if lhs_otype.__name__ != rhs_otype.__name__: if rhs_otype.spintrace[2] is not None: - if lhs_otype.__name__ == rhs_otype.spintrace[2]().__name__: + rhs_spintrace_otype = rhs_otype.spintrace[2]() + if accumulate_compatible(lhs_otype, rhs_spintrace_otype): rhs_gradient = g(g.spin_trace(rhs_gradient)) rhs_otype = rhs_gradient.otype + elif rhs_spintrace_otype.colortrace[2] is not None: + if accumulate_compatible(lhs_otype, rhs_spintrace_otype.colortrace[2]()): + rhs_gradient = g(g.trace(rhs_gradient)) + rhs_otype = rhs_gradient.otype if rhs_otype.colortrace[2] is not None: - if lhs_otype.__name__ == rhs_otype.colortrace[2]().__name__: + if accumulate_compatible(lhs_otype, rhs_otype.colortrace[2]()): rhs_gradient = g(g.color_trace(rhs_gradient)) rhs_otype = rhs_gradient.otype diff --git a/lib/gpt/core/object_type/container.py b/lib/gpt/core/object_type/container.py index ff110d0f..40cc6ccf 100644 --- a/lib/gpt/core/object_type/container.py +++ b/lib/gpt/core/object_type/container.py @@ -74,9 +74,12 @@ def data_otype(self): def is_self_dual(self): return True - def identity(): + def identity(self): return 1.0 + def infinitesimal_to_cartesian(self, a, da): + return da + ### # Matrices and vectors in color space