From cfcd1e2c506ee7b0b785eff255e1181f05eb9b79 Mon Sep 17 00:00:00 2001
From: Christoph Lehner
Date: Mon, 8 Jul 2024 20:57:54 +0200
Subject: [PATCH] adjust

---
 lib/gpt/ad/reverse/node.py | 30 +++++++++++++++++++++---
 lib/gpt/ad/reverse/util.py | 48 +++++++++++++++++++++++++-------------
 lib/gpt/core/util.py       |  8 +++----
 tests/qcd/gauge.py         |  4 ++--
 4 files changed, 65 insertions(+), 25 deletions(-)

diff --git a/lib/gpt/ad/reverse/node.py b/lib/gpt/ad/reverse/node.py
index a0b0ac24..c1d62959 100644
--- a/lib/gpt/ad/reverse/node.py
+++ b/lib/gpt/ad/reverse/node.py
@@ -17,7 +17,12 @@
 #    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 #
 import gpt as g
-from gpt.ad.reverse.util import get_container, get_mul_container, convert_container
+from gpt.ad.reverse.util import (
+    get_container,
+    get_mul_container,
+    get_div_container,
+    convert_container,
+)
 from gpt.ad.reverse import foundation
 from gpt.core.foundation import base

@@ -177,8 +182,27 @@ def _backward(z):
     def __rmul__(x, y):
         return node_base.__mul__(y, x)

-    def __truediv__(self, other):
-        return (1.0 / other) * self
+    def __truediv__(x, y):
+        if not isinstance(x, node_base):
+            x = node_base(x, with_gradient=False)
+
+        if not isinstance(y, node_base):
+            y = node_base(y, with_gradient=False)
+
+        z_container = get_div_container(x._container, y._container)
+
+        def _forward():
+            return x.value / y.value
+
+        # must not capture z here, otherwise we create a reference cycle
+        def _backward(z):
+            # z = x / y  ->  dz = dx / y - x / y^2 dy
+            if x.with_gradient:
+                x.gradient += z.gradient / g.adj(y.value)
+            if y.with_gradient:
+                y.gradient -= g.adj(x.value) / g.adj(y.value) / g.adj(y.value) * z.gradient
+
+        return node_base(_forward, _backward, (x, y), _container=z_container, _tag="/")

     def __neg__(self):
         return (-1.0) * self
diff --git a/lib/gpt/ad/reverse/util.py b/lib/gpt/ad/reverse/util.py
index 2213f7b8..a878c6f5 100644
--- a/lib/gpt/ad/reverse/util.py
+++ b/lib/gpt/ad/reverse/util.py
@@ -37,11 +37,11 @@ def __init__(self, *tag):
         if isinstance(tag[0], tuple):
             tag = tag[0]
         if tag[1] is None:
-            tag=[complex]
+            tag = [complex]
         elif tag[0] is None:
-            tag=[g.tensor,tag[1]]
+            tag = [g.tensor, tag[1]]
         else:
-            tag=[g.lattice, tag[0], tag[1]]
+            tag = [g.lattice, tag[0], tag[1]]

         self.tag = tag

@@ -73,7 +73,7 @@ def zero(self):
         else:
             raise Exception("Unknown type")
         return r
-    
+
     def __eq__(self, other):
         return str(self) == str(other)

@@ -85,7 +85,7 @@ def __str__(self):
             r = r + ";" + str(self.tag[1].obj)
         return r

-    
+
 def get_container(x):
     if isinstance(x, g.ad.reverse.node_base):
         return get_container(x.value)
@@ -109,6 +109,11 @@ def get_mul_container(x, y):
     return container((g.expr(rx) * g.expr(ry)).container())


+def get_div_container(x, y):
+    assert isinstance(y.representative(), complex)  # divisor must be a scalar
+    return x  # the quotient lives in the numerator's container
+
+
 def get_unary_container(x, unary):
     rx = x.representative()
     return container(g.expr(unary(rx)).container())
@@ -129,7 +134,7 @@ def convert_container(v, x, y, operand):
     backward_spin_trace = False
     backward_color_trace = False
     backward_trace = False
-    
+
     if v._container.tag[0] != g.lattice and c.tag[0] == g.lattice:
         backward_sum = True

@@ -137,7 +142,7 @@ def convert_container(v, x, y, operand):
     if v._container.tag[-1].__name__ != c.tag[-1].__name__:
         rhs_otype = c.tag[-1]
         lhs_otype = v._container.tag[-1]
-        
+
         if rhs_otype.spintrace[2] is not None:
             rhs_spintrace_otype = rhs_otype.spintrace[2]()
             if accumulate_compatible(lhs_otype, rhs_spintrace_otype):
@@ -155,41 +160,52 @@ def convert_container(v, x, y, operand):
                     rhs_otype = rhs_colortrace_otype

         if not accumulate_compatible(rhs_otype, lhs_otype):
-            raise Exception("Conversion incomplete:" + rhs_otype.__name__ + ":" + lhs_otype.__name__)
+            raise Exception(
+                "Conversion incomplete:" + rhs_otype.__name__ + ":" + lhs_otype.__name__
+            )

     # g.message("Need to modify to",v._container,"from",c,":",backward_sum, backward_trace, backward_spin_trace, backward_color_trace)
     assert backward_trace or backward_color_trace or backward_spin_trace or backward_sum
-

     def _forward():
         value = v.value
-        #if backward_sum:
+        # if backward_sum:
         #    r = c.representative()
         #    r[:] = value
         #    value = r
         # print("test",backward_trace,backward_spin_trace,backward_color_trace)
-        
+
         return value

     def _backward(z):
         if v.with_gradient:
             gradient = z.gradient
-            
+
             if backward_trace:
                 gradient = g.trace(gradient)
-            
+
             if backward_color_trace:
                 gradient = g.color_trace(gradient)

             if backward_spin_trace:
                 gradient = g.spin_trace(gradient)
-            
+
             if backward_sum:
                 gradient = g.sum(gradient)

-            v.gradient += g(gradient) if backward_trace or backward_color_trace or backward_spin_trace else gradient
+            v.gradient += (
+                g(gradient)
+                if backward_trace or backward_color_trace or backward_spin_trace
+                else gradient
+            )

     # print("Ran conversion with sum/tr",backward_sum,backward_trace,backward_spin_trace,backward_color_trace)
-    return g.ad.reverse.node_base(_forward, _backward, (v,), _container=c, _tag="change to " + str(c) + " from " + str(v._container))
+    return g.ad.reverse.node_base(
+        _forward,
+        _backward,
+        (v,),
+        _container=c,
+        _tag="change to " + str(c) + " from " + str(v._container),
+    )
diff --git a/lib/gpt/core/util.py b/lib/gpt/core/util.py
index 2278f815..db6feb8b 100644
--- a/lib/gpt/core/util.py
+++ b/lib/gpt/core/util.py
@@ -24,14 +24,14 @@

 # test if of number type
 def is_num(x):
-    return isinstance(x, (int, float, complex, gpt.qfloat, gpt.qcomplex)) and not isinstance(
-        x, bool
-    )
+    return isinstance(
+        x, (int, float, complex, np.int64, gpt.qfloat, gpt.qcomplex)
+    ) and not isinstance(x, bool)


 # adj a number
 def adj_num(x):
-    if isinstance(x, (int, float, gpt.qfloat)):
+    if isinstance(x, (int, float, gpt.qfloat, np.int64)):
         return x
     elif isinstance(x, complex):
         return x.conjugate()
diff --git a/tests/qcd/gauge.py b/tests/qcd/gauge.py
index 2bf80a46..c7d8c1fb 100755
--- a/tests/qcd/gauge.py
+++ b/tests/qcd/gauge.py
@@ -173,8 +173,8 @@
     # test local (factorizable) stout smearing
     lsm = g.qcd.gauge.smear.local_stout(rho=0.05, dimension=1, checkerboard=g.even)
     action_sm = action.transformed(lsm)
-    action_sm.assert_gradient_error(rng, U, U, 1e-3, 1e-8)
-    lsm.assert_log_det_jacobian(U, 1e-5, (2, 2, 2, 0), 1e-8)
+    action_sm.assert_gradient_error(rng, U, U, 1e-3, 1e-7)
+    lsm.assert_log_det_jacobian(U, 1e-5, (2, 2, 2, 0), 1e-7)

     action_log_det = lsm.action_log_det_jacobian()
     action_log_det.assert_gradient_error(rng, U, U, 1e-3, 1e-8)
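
For reference, the backward rule implemented by the new __truediv__ above: get_div_container restricts the divisor to a scalar (complex) container, and with gradients stored in the adjoint (Wirtinger) convention used by the x-branch — for a real loss L, the gradient of a complex variable w holds dL/d(conj(w)) — the quotient z = x / y back-propagates as x.gradient += z.gradient / adj(y) and y.gradient -= adj(x) / adj(y)^2 * z.gradient. A minimal plain-Python sanity check of this rule against finite differences (editorial sketch, independent of gpt):

    # Finite-difference check of the division backward rule.
    # Convention: for a real loss L, the stored gradient of complex w is dL/d(conj(w)).
    x = 0.7 + 0.3j
    y = 1.2 - 0.5j

    def loss(x, y):
        z = x / y
        return (z * z.conjugate()).real  # real objective L = |z|^2

    # analytic gradients: dL/d(conj(z)) = z, then chain rule through z = x / y
    gz = x / y
    gx = gz / y.conjugate()
    gy = -x.conjugate() / y.conjugate() ** 2 * gz

    # to first order in a step dw, dL = 2 Re(conj(g_w) dw)
    eps = 1e-6
    for name, gw in [("x", gx), ("y", gy)]:
        for dw in (eps, 1j * eps):
            dx, dy = (dw, 0.0) if name == "x" else (0.0, dw)
            num = loss(x + dx, y + dy) - loss(x, y)
            ana = 2 * (gw.conjugate() * dw).real
            assert abs(num - ana) < 1e-9, (name, num, ana)
    print("division backward rule passes the finite-difference check")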
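
The _forward/_backward pair follows the node protocol of gpt.ad.reverse.node_base: each operator builds a node from a forward closure and a backward callback over its parents, and the callback receives the result node z as an argument rather than closing over it (hence the reference-cycle comment in the patch). A stripped-down, hypothetical scalar Node — illustrative only, not gpt's node_base — that mirrors this structure for division:

    # Toy scalar reverse-AD node mirroring the _forward/_backward protocol above
    # (hypothetical minimal class; complex scalars only).
    class Node:
        def __init__(self, value, parents=(), backward=None, with_gradient=True):
            self.value = value
            self.parents = parents
            self._backward = backward if backward is not None else (lambda z: None)
            self.with_gradient = with_gradient
            self.gradient = 0.0

        def __truediv__(x, y):
            z = Node(x.value / y.value, parents=(x, y))

            # backward receives the result node z as an argument instead of
            # capturing it, mirroring the reference-cycle note in the patch
            def _backward(z):
                if x.with_gradient:
                    x.gradient += z.gradient / y.value.conjugate()
                if y.with_gradient:
                    y.gradient -= x.value.conjugate() / y.value.conjugate() ** 2 * z.gradient

            z._backward = _backward
            return z

        def backward(self):
            # seed the output gradient, then sweep in reverse topological order
            self.gradient = 1.0
            order, seen = [], set()

            def visit(n):
                if id(n) not in seen:
                    seen.add(id(n))
                    for p in n.parents:
                        visit(p)
                    order.append(n)

            visit(self)
            for n in reversed(order):
                n._backward(n)

    x = Node(0.7 + 0.3j)
    y = Node(1.2 - 0.5j)
    z = x / y
    z.backward()
    print(x.gradient)  # == 1 / y.value.conjugate()
    print(y.gradient)  # == -x.value.conjugate() / y.value.conjugate() ** 2

Running it prints 1/conj(y) and -conj(x)/conj(y)^2, matching the rule checked above with a unit seed gradient.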
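
The lib/gpt/core/util.py hunk widens is_num and adj_num to accept np.int64: NumPy integer scalars are not Python ints, so they previously failed the number test even though they behave as plain numbers (and, like int, are their own adjoint). A small standalone illustration — the gpt.qfloat/gpt.qcomplex entries are omitted here for self-containedness:

    import numpy as np

    # mirrors the patched predicate, minus the gpt quaternion types
    def is_num(x):
        return isinstance(x, (int, float, complex, np.int64)) and not isinstance(x, bool)

    n = np.prod([4, 4, 4, 4], dtype=np.int64)  # e.g. a grid-site count
    assert not isinstance(n, int)   # why np.int64 must be listed explicitly
    assert is_num(n)
    assert not is_num(True)         # bools are still excluded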