From 6cb9aa8fb0e5f5c17fa74715f5557a67d98d908d Mon Sep 17 00:00:00 2001
From: Christoph Lehner
Date: Mon, 25 Mar 2024 14:00:58 +0100
Subject: [PATCH] add tests

---
 lib/gpt/ad/reverse/util.py |  1 +
 tests/ad/ad.py             | 46 +++++++++++++++++++++++++++-----------
 2 files changed, 34 insertions(+), 13 deletions(-)

diff --git a/lib/gpt/ad/reverse/util.py b/lib/gpt/ad/reverse/util.py
index 98cbc12e..c75bf5e6 100644
--- a/lib/gpt/ad/reverse/util.py
+++ b/lib/gpt/ad/reverse/util.py
@@ -59,6 +59,7 @@ def accumulate_gradient(lhs, rhs_gradient, getter=None, setter=None):
     grid, rhs_otype, is_list, nlist = rhs_gradient.container()
     assert not is_list  # for now
     lhs_otype = lhs_gradient.otype
+
     if lhs_otype.__name__ != rhs_otype.__name__:
         if rhs_otype.spintrace[2] is not None:
             rhs_spintrace_otype = rhs_otype.spintrace[2]()
diff --git a/tests/ad/ad.py b/tests/ad/ad.py
index 591266ae..2a4ac892 100755
--- a/tests/ad/ad.py
+++ b/tests/ad/ad.py
@@ -24,6 +24,12 @@
 b2 = rad.node(g.vspincolor(grid))
 x = rad.node(g.vspincolor(grid))
 t1 = rad.node(g.tensor(a1.value.otype))
+s1 = rad.node(g.complex(grid))
+s2 = rad.node(g.complex(grid))
+m1 = rad.node(g.mspin(grid))
+
+# use additive group instead of matrix multiplication
+u1 = rad.node(g.mcolor(grid), infinitesimal_to_cartesian=False)
 
 # relu without leakage
 relu = g.component.relu()
@@ -32,21 +38,28 @@ def real(x):
     return 0.5 * (x + g.adj(x))
 
 # test a few simple models
-for c, learn_rate in [
+nid = 0
+for c, learn_rate, args in [
     (
         g.norm2(b1 + 1j * b2)
         + g.inner_product(a1 + 1j * a2, a1 - 1j * a2)
         + g.norm2(t1)
         + g.norm2(x),
         1e-1,
+        [b1, b2, a1, a2, t1, x],
     ),
     (
         g.norm2(a1)
         + 3.0 * g.norm2(a2 * b1 + b2 + t1 * x)
         + g.inner_product(b1 + 1j * b2, b1 + 1j * b2),
         1e-1,
+        [b1, b2, a1, a2, t1, x],
+    ),
+    (
+        g.norm2(relu(a2 * relu(a1 * x + b1) + g.adj(t1 * x + b2)) - x),
+        1e-1,
+        [b1, b2, a1, a2, t1, x],
     ),
-    (g.norm2(relu(a2 * relu(a1 * x + b1) + g.adj(t1 * x + b2)) - x), 1e-1),
     (
         g.norm2(
             2.0 * a2 * t1 * a1 * relu(a1 * x + b1)
@@ -56,10 +69,18 @@ def real(x):
             + t1 * g.cshift(a1 * x, 1, -1)
         ),
         1e-1,
+        [b1, b2, a1, a2, t1, x],
     ),
+    (g.norm2(s1 * x + s2 * x), 1e-1, [s1, s2, x]),
+    (g.norm2((s1 + s2) * x), 1e-1, [s1, s2, x]),
+    (g.norm2(s1 + s2), 1e-1, [s1, s2]),
+    (g.norm2(s1 * s2), 1e-1, [s1, s2]),
+    (g.norm2((s1 * s2) * x), 1e-1, [s1, s2, x]),
+    (g.norm2(u1 * x), 1e-1, [x, u1]),
+    (g.norm2(m1 * x + u1 * x), 1e-1, [m1, x, u1]),
 ]:
     # randomize values
-    rng.cnormal([a1.value, a2.value, b1.value, b2.value, x.value, t1.value])
+    rng.cnormal([vv.value for vv in args])
 
     # get gradient for real and imaginary part
     for ig, part in [(1.0, lambda x: x.real), (1.0j, lambda x: x.imag)]:
@@ -67,8 +88,10 @@ def real(x):
 
         # numerically test derivatives
        eps = 1e-6
-        g.message(f"Numerical derivatives with eps = {eps} with initial gradient {ig}")
-        for var in [a1, a2, b1, b2, x, t1]:
+        g.message(
+            f"Numerical derivatives of expression {nid} with eps = {eps} with initial gradient {ig}"
+        )
+        for var in args:
             lt = rng.normal(var.value.new())
             var.value += lt * eps
             v1 = part(c(with_gradients=False))
@@ -78,9 +101,7 @@ def real(x):
 
             num_result = (v1 - v2) / eps / 2.0
             ad_result = g.inner_product(lt, var.gradient).real
-            err = abs(num_result - ad_result) / (
-                abs(num_result) + abs(ad_result) + grid.gsites
-            )
+            err = abs(num_result - ad_result) / (abs(num_result) + abs(ad_result) + grid.gsites)
             g.message(f"Error of gradient's real part: {err} {num_result} {ad_result}")
             assert err < 1e-4
 
@@ -92,21 +113,20 @@ def real(x):
 
             num_result = (v1 - v2) / eps / 2.0
             ad_result = g.inner_product(lt, var.gradient).imag
-            err = abs(num_result - ad_result) / (
-                abs(num_result) + abs(ad_result) + grid.gsites
-            )
+            err = abs(num_result - ad_result) / (abs(num_result) + abs(ad_result) + grid.gsites)
             g.message(f"Error of gradient's imaginary part: {err} {num_result} {ad_result}")
             assert err < 1e-4
 
     # create something to minimize
-    f = real(c).functional(a1, b1, a2, b2, t1)
-    ff = [a1.value, b1.value, a2.value, b2.value, t1.value]
+    f = real(c).functional(*args)
+    ff = [vv.value for vv in args]
     v0 = f(ff)
     opt = g.algorithms.optimize.adam(maxiter=40, eps=1e-7, alpha=learn_rate)
     opt(f)(ff, ff)
     v1 = f(ff)
     g.message(f"Reduced value from {v0} to {v1} with Adam")
     assert v1 < v0
+    nid += 1
 
 
 #####################################
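
Note, not part of the patch: the new test cases validate each reverse-mode gradient by comparing the inner product of a random direction with the AD gradient against a central finite difference of the expression. A minimal self-contained sketch of that check, with numpy arrays standing in for gpt lattices (all names below are illustrative stand-ins, not the gpt API):

    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.normal(size=(8, 8))
    x = rng.normal(size=8)

    def cost(v):
        # stand-in for one of the test expressions, e.g. g.norm2(a1 * x)
        return float(np.linalg.norm(A @ v) ** 2)

    grad = 2.0 * A.T @ (A @ x)  # analytic gradient, playing the role of var.gradient

    eps = 1e-6
    lt = rng.normal(size=8)  # random direction, like rng.normal(var.value.new())
    num_result = (cost(x + eps * lt) - cost(x - eps * lt)) / eps / 2.0  # central difference
    ad_result = float(lt @ grad)  # inner product of direction and gradient
    err = abs(num_result - ad_result) / (abs(num_result) + abs(ad_result) + x.size)
    assert err < 1e-4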