Commit cfcd1e2

adjust
lehner committed Jul 8, 2024
1 parent 98e6f3c commit cfcd1e2
Showing 4 changed files with 65 additions and 25 deletions.
30 changes: 27 additions & 3 deletions lib/gpt/ad/reverse/node.py
@@ -17,7 +17,12 @@
 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 #
 import gpt as g
-from gpt.ad.reverse.util import get_container, get_mul_container, convert_container
+from gpt.ad.reverse.util import (
+    get_container,
+    get_mul_container,
+    get_div_container,
+    convert_container,
+)
 from gpt.ad.reverse import foundation
 from gpt.core.foundation import base

@@ -177,8 +182,27 @@ def _backward(z):
     def __rmul__(x, y):
         return node_base.__mul__(y, x)

-    def __truediv__(self, other):
-        return (1.0 / other) * self
+    def __truediv__(x, y):
+        if not isinstance(x, node_base):
+            x = node_base(x, with_gradient=False)
+
+        if not isinstance(y, node_base):
+            y = node_base(y, with_gradient=False)
+
+        z_container = get_div_container(x._container, y._container)
+
+        def _forward():
+            return x.value / y.value
+
+        # not allowed to capture z, otherwise have reference loop!
+        def _backward(z):
+            # z = x / y -> dz = dx/y - x/y^2 dy
+            if x.with_gradient:
+                x.gradient += z.gradient / g.adj(y.value)
+            if y.with_gradient:
+                y.gradient -= g.adj(x.value) / y.value / y.value * z.gradient
+
+        return node_base(_forward, _backward, (x, y), _container=z_container, _tag="/")

     def __neg__(self):
         return (-1.0) * self
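
For orientation, a minimal self-contained sketch (plain Python, not the gpt API) of the reverse-mode quotient rule the new __truediv__ encodes; Node and backward_div are hypothetical stand-ins for node_base and its _backward closure, and for complex scalars g.adj reduces to conjugation:

    class Node:
        def __init__(self, value):
            self.value = value
            self.gradient = 0.0

    def backward_div(x, y, z):
        # mirrors _backward above: z = x / y -> dz = dx/y - x/y^2 dy
        x.gradient += z.gradient / y.value.conjugate()
        y.gradient -= x.value.conjugate() / y.value / y.value * z.gradient

    x, y = Node(2.0 + 1.0j), Node(1.0 - 3.0j)
    z = Node(x.value / y.value)
    z.gradient = 1.0  # seed the output gradient
    backward_div(x, y, z)
    print(x.gradient, y.gradient)
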
48 changes: 32 additions & 16 deletions lib/gpt/ad/reverse/util.py
@@ -37,11 +37,11 @@ def __init__(self, *tag):
         if isinstance(tag[0], tuple):
             tag = tag[0]
         if tag[1] is None:
-            tag=[complex]
+            tag = [complex]
         elif tag[0] is None:
-            tag=[g.tensor,tag[1]]
+            tag = [g.tensor, tag[1]]
         else:
-            tag=[g.lattice, tag[0], tag[1]]
+            tag = [g.lattice, tag[0], tag[1]]

         self.tag = tag

@@ -73,7 +73,7 @@ def zero(self):
         else:
             raise Exception("Unknown type")
         return r

     def __eq__(self, other):
         return str(self) == str(other)

@@ -85,7 +85,7 @@ def __str__(self):
             r = r + ";" + str(self.tag[1].obj)
         return r


 def get_container(x):
     if isinstance(x, g.ad.reverse.node_base):
         return get_container(x.value)
@@ -109,6 +109,11 @@ def get_mul_container(x, y):
     return container((g.expr(rx) * g.expr(ry)).container())


+def get_div_container(x, y):
+    assert isinstance(y.representative(), complex)
+    return x
+
+
 def get_unary_container(x, unary):
     rx = x.representative()
     return container(g.expr(unary(rx)).container())
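
Note on the contract: get_div_container only admits denominators whose representative is a complex scalar, and dividing by a scalar cannot change the numerator's type, so the numerator's container is returned unchanged. A numpy stand-in (shapes invented for illustration) of why this holds:

    import numpy as np

    field = np.ones((4, 3, 3), dtype=complex)  # stand-in for a lattice of 3x3 matrices
    scalar = 2.0 + 1.0j                        # the only kind of denominator allowed
    assert (field / scalar).shape == field.shape  # "container" is unchanged
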
@@ -129,15 +134,15 @@ def convert_container(v, x, y, operand):
     backward_spin_trace = False
     backward_color_trace = False
     backward_trace = False

     if v._container.tag[0] != g.lattice and c.tag[0] == g.lattice:
         backward_sum = True

     # now check otypes
     if v._container.tag[-1].__name__ != c.tag[-1].__name__:
         rhs_otype = c.tag[-1]
         lhs_otype = v._container.tag[-1]

         if rhs_otype.spintrace[2] is not None:
             rhs_spintrace_otype = rhs_otype.spintrace[2]()
             if accumulate_compatible(lhs_otype, rhs_spintrace_otype):
@@ -155,41 +160,52 @@ def convert_container(v, x, y, operand):
                 rhs_otype = rhs_colortrace_otype

         if not accumulate_compatible(rhs_otype, lhs_otype):
-            raise Exception("Conversion incomplete:" + rhs_otype.__name__ + ":" + lhs_otype.__name__)
+            raise Exception(
+                "Conversion incomplete:" + rhs_otype.__name__ + ":" + lhs_otype.__name__
+            )

     # g.message("Need to modify to",v._container,"from",c,":",backward_sum, backward_trace, backward_spin_trace, backward_color_trace)
     assert backward_trace or backward_color_trace or backward_spin_trace or backward_sum

     def _forward():
         value = v.value

-        #if backward_sum:
+        # if backward_sum:
         #     r = c.representative()
         #     r[:] = value
         #     value = r
         # print("test",backward_trace,backward_spin_trace,backward_color_trace)

         return value

     def _backward(z):
         if v.with_gradient:
             gradient = z.gradient

             if backward_trace:
                 gradient = g.trace(gradient)

             if backward_color_trace:
                 gradient = g.color_trace(gradient)

             if backward_spin_trace:
                 gradient = g.spin_trace(gradient)

             if backward_sum:
                 gradient = g.sum(gradient)

-            v.gradient += g(gradient) if backward_trace or backward_color_trace or backward_spin_trace else gradient
+            v.gradient += (
+                g(gradient)
+                if backward_trace or backward_color_trace or backward_spin_trace
+                else gradient
+            )

     # print("Ran conversion with sum/tr",backward_sum,backward_trace,backward_spin_trace,backward_color_trace)

-    return g.ad.reverse.node_base(_forward, _backward, (v,), _container=c, _tag="change to " + str(c) + " from " + str(v._container))
+    return g.ad.reverse.node_base(
+        _forward,
+        _backward,
+        (v,),
+        _container=c,
+        _tag="change to " + str(c) + " from " + str(v._container),
+    )
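
The _backward above undoes an implicit promotion: a gradient living in the promoted (larger) container is reduced back to the original one by tracing and/or summing. A rough numpy analogue of the reduction chain (shapes invented for illustration):

    import numpy as np

    grad = np.ones((8, 3, 3), dtype=complex)  # per-site matrix gradient (stand-in)
    grad = np.trace(grad, axis1=1, axis2=2)   # analogue of backward_color_trace
    grad = grad.sum()                         # analogue of backward_sum
    print(grad)                               # scalar gradient for the original node
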
8 changes: 4 additions & 4 deletions lib/gpt/core/util.py
@@ -24,14 +24,14 @@

 # test if of number type
 def is_num(x):
-    return isinstance(x, (int, float, complex, gpt.qfloat, gpt.qcomplex)) and not isinstance(
-        x, bool
-    )
+    return isinstance(
+        x, (int, float, complex, np.int64, gpt.qfloat, gpt.qcomplex)
+    ) and not isinstance(x, bool)


 # adj a number
 def adj_num(x):
-    if isinstance(x, (int, float, gpt.qfloat)):
+    if isinstance(x, (int, float, gpt.qfloat, np.int64)):
         return x
     elif isinstance(x, complex):
         return x.conjugate()
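
A quick self-contained check of the predicate this hunk extends (gpt.qfloat and gpt.qcomplex omitted for brevity); bool is excluded explicitly because it is a subclass of int:

    import numpy as np

    def is_num(x):
        return isinstance(x, (int, float, complex, np.int64)) and not isinstance(x, bool)

    assert is_num(np.int64(3))  # now recognized as a number
    assert not is_num(True)     # bool stays excluded
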
4 changes: 2 additions & 2 deletions tests/qcd/gauge.py
@@ -173,8 +173,8 @@
 # test local (factorizable) stout smearing
 lsm = g.qcd.gauge.smear.local_stout(rho=0.05, dimension=1, checkerboard=g.even)
 action_sm = action.transformed(lsm)
-action_sm.assert_gradient_error(rng, U, U, 1e-3, 1e-8)
-lsm.assert_log_det_jacobian(U, 1e-5, (2, 2, 2, 0), 1e-8)
+action_sm.assert_gradient_error(rng, U, U, 1e-3, 1e-7)
+lsm.assert_log_det_jacobian(U, 1e-5, (2, 2, 2, 0), 1e-7)

 action_log_det = lsm.action_log_det_jacobian()
 action_log_det.assert_gradient_error(rng, U, U, 1e-3, 1e-8)
