
Commit

more edge cases
lehner committed Nov 1, 2023
1 parent cc42c66 commit 8fa5938
Showing 5 changed files with 48 additions and 32 deletions.
23 changes: 12 additions & 11 deletions lib/gpt/ad/reverse/node.py
@@ -17,7 +17,7 @@
 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 #
 import gpt as g
-from gpt.ad.reverse.util import accumulate
+from gpt.ad.reverse.util import accumulate_gradient


 verbose_memory = g.default.is_verbose("ad_memory")
@@ -96,7 +96,7 @@ def __init__(self, _forward, _backward=lambda z: None, _children=(), with_gradie
             # print(gctr)

     def zero_gradient(self):
-        self.gradient = g(0 * self.value)
+        self.gradient = g(0.0 * self.value)

     def __mul__(x, y):
         if not isinstance(x, node_base):
@@ -111,9 +111,9 @@ def _forward():
         # not allowed to capture z, otherwise have reference loop!
         def _backward(z):
             if x.with_gradient:
-                accumulate(x.gradient, z.gradient * g.adj(y.value))
+                accumulate_gradient(x, z.gradient * g.adj(y.value))
             if y.with_gradient:
-                accumulate(y.gradient, g.adj(x.value) * z.gradient)
+                accumulate_gradient(y, g.adj(x.value) * z.gradient)

         return node_base(_forward, _backward, (x, y))

@@ -127,9 +127,9 @@ def _forward():
         # not allowed to capture z, otherwise have reference loop!
         def _backward(z):
             if x.with_gradient:
-                accumulate(x.gradient, z.gradient)
+                accumulate_gradient(x, z.gradient)
             if y.with_gradient:
-                accumulate(y.gradient, z.gradient)
+                accumulate_gradient(y, z.gradient)

         return node_base(_forward, _backward, (x, y))

@@ -140,9 +140,9 @@ def _forward():
         # not allowed to capture z, otherwise have reference loop!
         def _backward(z):
             if x.with_gradient:
-                accumulate(x.gradient, z.gradient)
+                accumulate_gradient(x, z.gradient)
             if y.with_gradient:
-                accumulate(y.gradient, -z.gradient)
+                accumulate_gradient(y, -z.gradient)

         return node_base(_forward, _backward, (x, y))

@@ -171,7 +171,7 @@ def forward(self, nodes, eager=True, free=None):
     def backward(self, nodes, first_gradient):
         fields_allocated = len(nodes)  # .values
         max_fields_allocated = fields_allocated
-        self.gradient = 1
+        self.gradient = 1.0
         for n in reversed(nodes):
             first_gradient_n = first_gradient[n]
             for m in first_gradient_n:
@@ -184,8 +184,9 @@ def backward(self, nodes, first_gradient):
                    n.gradient = None
                    fields_allocated -= 1
                if n._forward is not None:
-                   n.value = None
-                   fields_allocated -= 1
+                   if n is not self:
+                       n.value = None
+                       fields_allocated -= 1

        if verbose_memory:
            g.message(
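The hunks above seed the output gradient of backward() with the scalar 1.0 and route every child contribution through accumulate_gradient. A minimal, self-contained sketch of that reverse-mode pattern, using plain Python floats and a hypothetical ToyNode class rather than gpt's node_base and lattice fields:

class ToyNode:
    def __init__(self, value, children=(), backward=lambda z: None):
        self.value = value
        self.gradient = 0.0
        self.children = children
        self._backward = backward

def toy_mul(x, y):
    z = ToyNode(x.value * y.value, (x, y))

    def _backward(z):
        # product rule: d(x*y)/dx = y, d(x*y)/dy = x
        x.gradient += z.gradient * y.value
        y.gradient += z.gradient * x.value

    z._backward = _backward
    return z

a = ToyNode(2.0)
b = ToyNode(3.0)
c = toy_mul(a, b)
c.gradient = 1.0  # scalar seed, as backward() now does with 1.0
c._backward(c)
assert a.gradient == 3.0 and b.gradient == 2.0

Seeding with a plain number rather than a field is what makes the promotion logic added to accumulate_gradient in lib/gpt/ad/reverse/util.py below necessary.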
10 changes: 5 additions & 5 deletions lib/gpt/ad/reverse/transform.py
@@ -18,7 +18,7 @@
 #
 import gpt as g
 from gpt.ad.reverse import node_base
-from gpt.ad.reverse.util import accumulate
+from gpt.ad.reverse.util import accumulate_gradient


 def inner_product(x, y):
@@ -28,9 +28,9 @@ def _forward():
     # not allowed to capture z, otherwise have reference loop!
     def _backward(z):
         if x.with_gradient:
-            accumulate(x.gradient, y.value * g.adj(z.gradient))
+            accumulate_gradient(x, y.value * g.adj(z.gradient))
         if y.with_gradient:
-            accumulate(y.gradient, x.value * g.adj(z.gradient))
+            accumulate_gradient(y, x.value * g.adj(z.gradient))

     return node_base(_forward, _backward, (x, y))

@@ -47,7 +47,7 @@ def _forward():
     def _backward(z):
         if x.with_gradient:
             active = g.component.drelu(a)(x.value)
-            accumulate(x.gradient, g.component.multiply(active, z.gradient))
+            accumulate_gradient(x, g.component.multiply(active, z.gradient))

     return node_base(_forward, _backward, (x,))

@@ -59,6 +59,6 @@ def _forward():
     # not allowed to capture z, otherwise have reference loop!
     def _backward(z):
         if x.with_gradient:
-            accumulate(x.gradient, g.cshift(z.gradient, direction, -displacement))
+            accumulate_gradient(x, g.cshift(z.gradient, direction, -displacement))

     return node_base(_forward, _backward, (x,))
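The last hunk keeps the backward rule for cshift: the incoming gradient is shifted by the opposite displacement. A small NumPy check of the underlying identity, with np.roll standing in for g.cshift (the sign convention may differ from gpt's), showing that the adjoint of a cyclic shift is the shift in the opposite direction:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=8)
y = rng.normal(size=8)
d = 3

# <roll(x, d), y> == <x, roll(y, -d)>, so shifting the gradient by -d
# implements the adjoint of shifting the value by +d
lhs = np.dot(np.roll(x, d), y)
rhs = np.dot(x, np.roll(y, -d))
assert np.allclose(lhs, rhs)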
14 changes: 10 additions & 4 deletions lib/gpt/ad/reverse/util.py
@@ -26,11 +26,17 @@ def is_field(x):
         return False
     elif isinstance(x, g.expr):
         return x.lattice() is not None
+    elif g.util.is_num(x):
+        return False
     else:
         raise Exception(f"Unknown object type {type(x)}")


-def accumulate(lhs, rhs):
-    if is_field(rhs) and not is_field(lhs):
-        rhs = g.sum(rhs)
-    lhs += rhs
+def accumulate_gradient(lhs, rhs_gradient):
+    rhs_field = is_field(rhs_gradient)
+    lhs_field = is_field(lhs.gradient)
+    if rhs_field and not lhs_field:
+        rhs_gradient = g.sum(rhs_gradient)
+    if g.util.is_num(lhs.gradient) and isinstance(rhs_gradient, g.expr):
+        rhs_gradient = g(rhs_gradient)
+    lhs.gradient += rhs_gradient
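The new accumulate_gradient covers two edge cases the old accumulate missed: a field-valued contribution flowing into a scalar gradient is reduced with g.sum before accumulation, and a purely numeric lhs.gradient (such as the 1.0 seed set in backward()) is promoted by evaluating the incoming expression with g(...). A standalone mock of the reduction case, with NumPy arrays standing in for lattice fields and all names illustrative rather than gpt's:

import numpy as np

def is_field(x):
    return isinstance(x, np.ndarray)

class MockNode:
    def __init__(self, gradient):
        self.gradient = gradient

def mock_accumulate_gradient(lhs, rhs_gradient):
    # a field contribution to a scalar-valued node is summed over the lattice
    if is_field(rhs_gradient) and not is_field(lhs.gradient):
        rhs_gradient = rhs_gradient.sum()
    lhs.gradient += rhs_gradient

scalar_node = MockNode(gradient=1.0)                 # e.g. the seed of backward()
mock_accumulate_gradient(scalar_node, np.ones(4))    # summed to 4.0 before accumulation
assert scalar_node.gradient == 5.0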
24 changes: 16 additions & 8 deletions lib/gpt/core/expr.py
@@ -78,6 +78,12 @@ def get_single(self):
             self.val[0][1][0][1],
         )

+    def is_num(self):
+        return len(self.val) == 1 and len(self.val[0][1]) == 0
+
+    def get_num(self):
+        return self.val[0][0]
+
     def lattice(self):
         for v in self.val:
             for i in v[1]:
@@ -297,6 +303,16 @@ def expr_eval(first, second=None, ac=False):
     # apply matrix_operators
     e = apply_type_right_to_left(e, gpt.matrix_operator)

+    t("fast return")
+    # fast return if already a lattice
+    if dst is None:
+        if e.is_single(gpt.lattice):
+            ue, uf, v = e.get_single()
+            if uf == factor_unary.NONE and ue == expr_unary.NONE:
+                return v
+        elif e.is_num():
+            return e.get_num()
+
     t("prepare")
     if dst is None:
         lat = e.lattice()
@@ -308,14 +324,6 @@
         grid = lat[0].grid
         nlat = len(lat)

-    t("fast return")
-    # fast return if already a lattice
-    if dst is None:
-        if e.is_single(gpt.lattice):
-            ue, uf, v = e.get_single()
-            if uf == factor_unary.NONE and ue == expr_unary.NONE:
-                return v
-
     # verbose output
     if verbose:
         gpt.message("eval: " + str(e))
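With is_num and get_num, the fast-return block is moved ahead of the prepare step, so a factor-free expression can be returned directly as its coefficient without a destination lattice being inferred. A minimal mock of the term representation implied by the diff (each term a coefficient plus a list of factors; the class below is illustrative, not gpt's expr):

class MockExpr:
    def __init__(self, val):
        self.val = val  # list of (coefficient, [factors]) terms

    def is_num(self):
        # a single term with an empty factor list is a pure number
        return len(self.val) == 1 and len(self.val[0][1]) == 0

    def get_num(self):
        return self.val[0][0]

e = MockExpr([(2.5, [])])  # e.g. what a product of plain numbers reduces to
assert e.is_num() and e.get_num() == 2.5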
9 changes: 5 additions & 4 deletions tests/ad/ad.py
@@ -21,11 +21,9 @@
 x = rad.node(g.vspincolor(grid))
 t1 = rad.node(g.tensor(a1.value.otype))

-# randomize values
-rng.cnormal([a1.value, a2.value, b1.value, b2.value, x.value, t1.value])
-
 # test a few simple models
 for c, learn_rate in [
+    (rad.norm2(a1) + 3.0*rad.norm2(a2*b1 + b2 + t1*x), 1e-1),
     (rad.norm2(rad.relu(a2 * rad.relu(a1 * x + b1) + t1 * x + b2) - x), 1e-1),
     (
         rad.norm2(
@@ -36,10 +34,13 @@
             1e-1,
         ),
 ]:
+    # randomize values
+    rng.cnormal([a1.value, a2.value, b1.value, b2.value, x.value, t1.value])
+
     v0 = c()

     # numerically test derivatives
-    eps = prec.eps**0.5
+    eps = prec.eps**0.5 * 100
     g.message(f"Numerical derivatives with eps = {eps}")
     for var in [a1, a2, b1, b2, x, t1]:
         lt = rng.cnormal(var.value.new())
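The test now re-randomizes the inputs for every model and enlarges the finite-difference step to prec.eps**0.5 * 100. A plain-NumPy sketch of this kind of numerical derivative check (symmetric difference quotient; the stencil used by the test itself may differ), illustrating why a step of order sqrt(machine epsilon) balances truncation error against round-off:

import numpy as np

def f(x):
    return float((x**2).sum())

x = np.array([1.0, 2.0, 3.0])
grad_analytic = 2.0 * x                       # known gradient of f
direction = np.array([1.0, 0.0, 0.0])

eps = np.finfo(np.float64).eps ** 0.5 * 100   # ~1.5e-6 in double precision
num = (f(x + eps * direction) - f(x - eps * direction)) / (2 * eps)
assert abs(num - grad_analytic @ direction) < 1e-5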
