Commit: add tests
lehner committed Mar 25, 2024
1 parent 802b133 commit 6cb9aa8
Showing 2 changed files with 34 additions and 13 deletions.
1 change: 1 addition & 0 deletions lib/gpt/ad/reverse/util.py
@@ -59,6 +59,7 @@ def accumulate_gradient(lhs, rhs_gradient, getter=None, setter=None):
     grid, rhs_otype, is_list, nlist = rhs_gradient.container()
     assert not is_list  # for now
     lhs_otype = lhs_gradient.otype
+
     if lhs_otype.__name__ != rhs_otype.__name__:
         if rhs_otype.spintrace[2] is not None:
             rhs_spintrace_otype = rhs_otype.spintrace[2]()
46 changes: 33 additions & 13 deletions tests/ad/ad.py
@@ -24,6 +24,12 @@
 b2 = rad.node(g.vspincolor(grid))
 x = rad.node(g.vspincolor(grid))
 t1 = rad.node(g.tensor(a1.value.otype))
+s1 = rad.node(g.complex(grid))
+s2 = rad.node(g.complex(grid))
+m1 = rad.node(g.mspin(grid))
+
+# use additive group instead of matrix multiplication
+u1 = rad.node(g.mcolor(grid), infinitesimal_to_cartesian=False)
 
 # relu without leakage
 relu = g.component.relu()
@@ -32,21 +38,28 @@ def real(x):
     return 0.5 * (x + g.adj(x))
 
 # test a few simple models
-for c, learn_rate in [
+nid = 0
+for c, learn_rate, args in [
     (
         g.norm2(b1 + 1j * b2)
         + g.inner_product(a1 + 1j * a2, a1 - 1j * a2)
         + g.norm2(t1)
         + g.norm2(x),
         1e-1,
+        [b1, b2, a1, a2, t1, x],
     ),
     (
         g.norm2(a1)
         + 3.0 * g.norm2(a2 * b1 + b2 + t1 * x)
         + g.inner_product(b1 + 1j * b2, b1 + 1j * b2),
         1e-1,
+        [b1, b2, a1, a2, t1, x],
     ),
+    (
+        g.norm2(relu(a2 * relu(a1 * x + b1) + g.adj(t1 * x + b2)) - x),
+        1e-1,
+        [b1, b2, a1, a2, t1, x],
+    ),
-    (g.norm2(relu(a2 * relu(a1 * x + b1) + g.adj(t1 * x + b2)) - x), 1e-1),
     (
         g.norm2(
             2.0 * a2 * t1 * a1 * relu(a1 * x + b1)
@@ -56,19 +69,29 @@ def real(x):
             + t1 * g.cshift(a1 * x, 1, -1)
         ),
         1e-1,
+        [b1, b2, a1, a2, t1, x],
     ),
+    (g.norm2(s1 * x + s2 * x), 1e-1, [s1, s2, x]),
+    (g.norm2((s1 + s2) * x), 1e-1, [s1, s2, x]),
+    (g.norm2(s1 + s2), 1e-1, [s1, s2]),
+    (g.norm2(s1 * s2), 1e-1, [s1, s2]),
+    (g.norm2((s1 * s2) * x), 1e-1, [s1, s2, x]),
+    (g.norm2(u1 * x), 1e-1, [x, u1]),
+    (g.norm2(m1 * x + u1 * x), 1e-1, [m1, x, u1]),
 ]:
     # randomize values
-    rng.cnormal([a1.value, a2.value, b1.value, b2.value, x.value, t1.value])
+    rng.cnormal([vv.value for vv in args])
 
     # get gradient for real and imaginary part
     for ig, part in [(1.0, lambda x: x.real), (1.0j, lambda x: x.imag)]:
         v0 = c(initial_gradient=ig)
 
         # numerically test derivatives
         eps = 1e-6
-        g.message(f"Numerical derivatives with eps = {eps} with initial gradient {ig}")
-        for var in [a1, a2, b1, b2, x, t1]:
+        g.message(
+            f"Numerical derivatives of expression {nid} with eps = {eps} with initial gradient {ig}"
+        )
+        for var in args:
             lt = rng.normal(var.value.new())
             var.value += lt * eps
             v1 = part(c(with_gradients=False))
@@ -78,9 +101,7 @@
 
             num_result = (v1 - v2) / eps / 2.0
             ad_result = g.inner_product(lt, var.gradient).real
-            err = abs(num_result - ad_result) / (
-                abs(num_result) + abs(ad_result) + grid.gsites
-            )
+            err = abs(num_result - ad_result) / (abs(num_result) + abs(ad_result) + grid.gsites)
             g.message(f"Error of gradient's real part: {err} {num_result} {ad_result}")
             assert err < 1e-4
 
@@ -92,21 +113,20 @@ def real(x):
 
             num_result = (v1 - v2) / eps / 2.0
             ad_result = g.inner_product(lt, var.gradient).imag
-            err = abs(num_result - ad_result) / (
-                abs(num_result) + abs(ad_result) + grid.gsites
-            )
+            err = abs(num_result - ad_result) / (abs(num_result) + abs(ad_result) + grid.gsites)
             g.message(f"Error of gradient's imaginary part: {err} {num_result} {ad_result}")
             assert err < 1e-4
 
     # create something to minimize
-    f = real(c).functional(a1, b1, a2, b2, t1)
-    ff = [a1.value, b1.value, a2.value, b2.value, t1.value]
+    f = real(c).functional(*args)
+    ff = [vv.value for vv in args]
     v0 = f(ff)
     opt = g.algorithms.optimize.adam(maxiter=40, eps=1e-7, alpha=learn_rate)
     opt(f)(ff, ff)
     v1 = f(ff)
     g.message(f"Reduced value from {v0} to {v1} with Adam")
     assert v1 < v0
+    nid += 1
 
 
 #####################################
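For context, the numerical check in the test loop above compares each reverse-mode gradient against a central finite difference along a random direction lt, i.e. (f(x + eps*lt) - f(x - eps*lt)) / (2*eps) versus the projection <lt, grad f>. The sketch below illustrates that pattern in plain NumPy; gradient_check is a hypothetical stand-in, not part of gpt, and the constant 1.0 replaces the grid.gsites regulator used in the test.

import numpy as np


def gradient_check(f, grad_f, x, eps=1e-6, seed=0):
    # compare a central finite difference along a random direction lt
    # with the inner product <lt, grad f(x)>
    rng = np.random.default_rng(seed)
    lt = rng.normal(size=x.shape)
    num_result = (f(x + eps * lt) - f(x - eps * lt)) / (2.0 * eps)
    ad_result = float(np.vdot(lt, grad_f(x)).real)
    # relative error, regulated so that tiny gradients do not inflate it
    return abs(num_result - ad_result) / (abs(num_result) + abs(ad_result) + 1.0)


# example: f(x) = |x|^2 has gradient 2x
err = gradient_check(lambda x: float(np.vdot(x, x).real), lambda x: 2.0 * x, np.ones(8))
assert err < 1e-4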

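The minimization check at the end of the loop builds a functional from the real part of each expression and asserts that 40 Adam steps reduce it. A rough stand-alone analogue, again in plain NumPy with a hand-written Adam update standing in for g.algorithms.optimize.adam:

import numpy as np


def adam_minimize(f, grad_f, x, alpha=1e-1, beta1=0.9, beta2=0.999, eps=1e-7, maxiter=40):
    # standard Adam update loop; returns the final iterate
    m = np.zeros_like(x)
    v = np.zeros_like(x)
    for t in range(1, maxiter + 1):
        gt = grad_f(x)
        m = beta1 * m + (1 - beta1) * gt
        v = beta2 * v + (1 - beta2) * gt**2
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)
        x = x - alpha * m_hat / (np.sqrt(v_hat) + eps)
    return x


# example: reduce f(x) = |x - target|^2, mirroring the "Reduced value ... with Adam" assertion
target = np.linspace(0.0, 1.0, 8)
f = lambda x: float(np.sum((x - target) ** 2))
grad_f = lambda x: 2.0 * (x - target)
x0 = np.zeros(8)
x1 = adam_minimize(f, grad_f, x0)
assert f(x1) < f(x0)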