
Commit

keep work
lehner committed Jul 8, 2024
1 parent 27517e6 commit 3a5d777
Showing 8 changed files with 505 additions and 127 deletions.
41 changes: 29 additions & 12 deletions lib/gpt/ad/reverse/foundation.py
@@ -17,7 +17,7 @@
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
import gpt as g
from gpt.ad.reverse.util import accumulate_gradient
from gpt.ad.reverse.util import container, get_unary_container


def inner_product(x, y, use_accelerator):
@@ -31,11 +31,11 @@ def _forward():
# not allowed to capture z, otherwise have reference loop!
def _backward(z):
if x.with_gradient: # z = adj(x) y -> x z = y -> x = y adj(z)
accumulate_gradient(x, y.value * g.adj(z.gradient))
x.gradient += y.value * g.adj(z.gradient) # y.value * g.adj(z.gradient)
if y.with_gradient: # z = adj(x) y -> y = x z
accumulate_gradient(y, x.value * z.gradient)
y.gradient += x.value * z.gradient

return {(0, 0): g.ad.reverse.node_base(_forward, _backward, (x, y))}
return {(0, 0): g.ad.reverse.node_base(_forward, _backward, (x, y), _container=container(complex), _tag="inner_product")}
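The comment above encodes the reverse-mode rule for z = adj(x) y. A minimal NumPy sketch of the same rule, with plain vectors standing in for gpt lattice fields (all names below are illustrative, not part of the library):

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=4) + 1j * rng.normal(size=4)
y = rng.normal(size=4) + 1j * rng.normal(size=4)

# forward: z = adj(x) y, a single complex number
z = np.vdot(x, y)

# reverse rules used above, with seed gradient dz = 1:
dz = 1.0 + 0.0j
dx = y * np.conj(dz)      # x.gradient += y.value * adj(z.gradient)
dy = x * dz               # y.gradient += x.value * z.gradient

# finite-difference check of the real part of one component of dx
eps = 1e-6
e0 = np.zeros(4); e0[0] = 1.0
num = (np.vdot(x + eps * e0, y).real - z.real) / eps
assert abs(num - dx[0].real) < 1e-5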


def norm2(x):
@@ -52,9 +52,9 @@ def _forward():
# not allowed to capture z, otherwise have reference loop!
def _backward(z):
if x.with_gradient:
accumulate_gradient(x, g.cshift(z.gradient, direction, -displacement))
x.gradient += g.cshift(z.gradient, direction, -displacement)

return g.ad.reverse.node_base(_forward, _backward, (x,))
return g.ad.reverse.node_base(_forward, _backward, (x,), _container=x._container, _tag="cshift(" + str(direction) + ", " + str(displacement) + ")")
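The backward rule says the adjoint of a shift is the opposite shift. A NumPy sketch with numpy.roll as a one-dimensional stand-in for g.cshift (the sign convention of the shift is assumed here, not taken from the library):

import numpy as np

x = np.arange(8.0)
displacement = 2

# forward: z[i] = x[i + displacement] (periodic), standing in for g.cshift
z = np.roll(x, -displacement)

# reverse rule used above: shift the incoming gradient by -displacement,
# i.e. x.gradient[j] receives z.gradient[j - displacement]
dz = np.zeros(8); dz[3] = 5.0        # arbitrary seed gradient
dx = np.roll(dz, displacement)

# z[3] = x[5], so the seed placed on z[3] must land on x[5]
assert dx[5] == 5.0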


def adj(x):
@@ -64,9 +64,9 @@ def _forward():
# not allowed to capture z, otherwise have reference loop!
def _backward(z):
if x.with_gradient:
accumulate_gradient(x, g.adj(z.gradient))
x.gradient += g.adj(z.gradient)

return g.ad.reverse.node_base(_forward, _backward, (x,))
return g.ad.reverse.node_base(_forward, _backward, (x,), _container=x._container, _tag="adj")


def trace(x, t):
@@ -76,9 +76,11 @@ def _forward():
# not allowed to capture z, otherwise have reference loop!
def _backward(z):
if x.with_gradient:
accumulate_gradient(x, g.identity(x.value) * z.gradient)
x.gradient += g.identity(x.value) * z.gradient

return g.ad.reverse.node_base(_forward, _backward, (x,))
z_container = get_unary_container(x._container, lambda v: g.trace(v, t))

return g.ad.reverse.node_base(_forward, _backward, (x,), _container=z_container)
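For the trace, the incoming scalar gradient is spread back onto the diagonal, which is what identity(x.value) * z.gradient does. A NumPy sketch of the same rule, with an ordinary matrix trace standing in for the color/spin trace:

import numpy as np

X = np.arange(9.0).reshape(3, 3)
dz = 2.0                             # seed gradient on z = tr(X)

# reverse rule used above: x.gradient += identity(x) * z.gradient
dX = np.eye(3) * dz

# finite-difference check: a diagonal entry carries the seed, an off-diagonal one does not
eps = 1e-6
E11 = np.zeros((3, 3)); E11[1, 1] = 1.0
E01 = np.zeros((3, 3)); E01[0, 1] = 1.0
assert abs(dz * (np.trace(X + eps * E11) - np.trace(X)) / eps - dX[1, 1]) < 1e-5
assert abs(dz * (np.trace(X + eps * E01) - np.trace(X)) / eps - dX[0, 1]) < 1e-5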


def sum(x):
@@ -88,13 +90,28 @@ def _forward():
# not allowed to capture z, otherwise have reference loop!
def _backward(z):
if x.with_gradient:
accumulate_gradient(x, g.identity(x.value) * z.gradient)
x.gradient += g.identity(x.value) * z.gradient

return g.ad.reverse.node_base(_forward, _backward, (x,))
return g.ad.reverse.node_base(_forward, _backward, (x,), _container=x._container.lattice_to_tensor())


def component_simple_map(operator, numpy_operator, extra_params, first, second):
if operator == "relu":
assert second is None
return g.ad.reverse.transform.relu(first, a=extra_params["a"])
raise Exception(f"component-wise operator {operator} not implemented in rev-AD")


def infinitesimal_to_cartesian(src, dsrc):
return src.value.otype.infinitesimal_to_cartesian(src, dsrc)


def identity(x):
def _forward():
return g.identity(x.value)

# not allowed to capture z, otherwise have reference loop!
def _backward(z):
pass

return g.ad.reverse.node_base(_forward, _backward, (x,), _container=x._container, _tag="identity(" + str(x._container) + ")")
96 changes: 72 additions & 24 deletions lib/gpt/ad/reverse/node.py
@@ -17,7 +17,7 @@
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
import gpt as g
from gpt.ad.reverse.util import accumulate_gradient, is_field
from gpt.ad.reverse.util import get_container, get_mul_container, convert_container
from gpt.ad.reverse import foundation
from gpt.core.foundation import base

@@ -73,8 +73,22 @@ def gradient(self, fields, dfields):
return [self.arguments[i].gradient for i in indices]


# gctr = 0
def str_traverse(node, indent=0):
if not callable(node._forward):
return (" "*indent) + "leaf(" + str(node._container) + ")"
else:
pre = " "*indent
if node._tag is not None:
tag = node._tag
else:
tag = str(node._forward)
ret = pre + "(" + tag + "):"
for x in node._children:
ret = ret + "\n" + str_traverse(x, indent+1)
return ret


# gctr = 0

class node_base(base):
foundation = foundation
@@ -87,30 +101,47 @@ def __init__(
_children=(),
with_gradient=True,
infinitesimal_to_cartesian=True,
_container=None,
_tag=None
):
# global gctr
# gctr+=1
if not callable(_forward):
if not callable(_forward) or isinstance(_forward, node_base):
self._forward = None
self.value = _forward
_container = get_container(_forward)
else:
self._forward = _forward
self.value = None
assert _container is not None
self._container = _container
self._backward = _backward
self._children = _children
self.with_gradient = with_gradient
self.infinitesimal_to_cartesian = infinitesimal_to_cartesian
self.gradient = None
self._tag = _tag

# def __del__(self):
# global gctr
# gctr-=1
# print(gctr)

def __str__(self):
return str_traverse(self)

def zero_gradient(self):
self.gradient = 0.0 * self.value
if isinstance(self.gradient, g.expr):
self.gradient = g(self.gradient)
self.gradient = self._container.zero()
if isinstance(self.value, g.ad.forward.series):
gradient = 0.0 * self.value
for t in gradient.terms:
gradient.terms[t] = self.gradient
self.gradient = gradient

value = self.value
while isinstance(value, node_base):
value = value.value
self.gradient = node_base(self.gradient)

def __mul__(x, y):
if not isinstance(x, node_base):
@@ -119,17 +150,25 @@ def __mul__(x, y):
if not isinstance(y, node_base):
y = node_base(y, with_gradient=False)

z_container = get_mul_container(x._container, y._container)

if x.with_gradient:
x = convert_container(x, z_container, y._container, lambda a, b: a * g.adj(b))

if y.with_gradient:
y = convert_container(y, x._container, z_container, lambda a, b: g.adj(a) * b)

def _forward():
return x.value * y.value

# not allowed to capture z, otherwise have reference loop!
def _backward(z):
if x.with_gradient:
accumulate_gradient(x, z.gradient * g.adj(y.value))
x.gradient += z.gradient * g.adj(y.value)
if y.with_gradient:
accumulate_gradient(y, g.adj(x.value) * z.gradient)
y.gradient += g.adj(x.value) * z.gradient

return node_base(_forward, _backward, (x, y))
return node_base(_forward, _backward, (x, y), _container=z_container, _tag="*")
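The two accumulation lines are the usual product rule in adjoint form: for Z = X Y and a scalar loss, grad_X = grad_Z adj(Y) and grad_Y = adj(X) grad_Z. A NumPy sketch under the convention dL = Re tr(adj(grad) dX), with plain 3x3 matrices standing in for the lattice objects (the container-conversion step above is omitted; all names here are illustrative):

import numpy as np

rng = np.random.default_rng(1)
def cplx(shape):
    return rng.normal(size=shape) + 1j * rng.normal(size=shape)

X, Y, G = cplx((3, 3)), cplx((3, 3)), cplx((3, 3))

# scalar loss L = Re tr(adj(G) X Y), so the seed gradient on Z = X Y is G
def loss(X, Y):
    return np.trace(G.conj().T @ (X @ Y)).real

# reverse rules used above
dX = G @ Y.conj().T          # z.gradient * adj(y.value)
dY = X.conj().T @ G          # adj(x.value) * z.gradient

# finite-difference check of one real component of dX
eps = 1e-6
E = np.zeros((3, 3)); E[0, 0] = 1.0
num = (loss(X + eps * E, Y) - loss(X, Y)) / eps
assert abs(num - dX[0, 0].real) < 1e-4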

def __rmul__(x, y):
return node_base.__mul__(y, x)
@@ -150,13 +189,15 @@ def setter(y, z):
return x.project(getter, setter)

def project(x, getter, setter):
assert False # for future use

def _forward():
return getter(x.value)

# not allowed to capture z, otherwise have reference loop!
def _backward(z):
if x.with_gradient:
accumulate_gradient(x, z.gradient, getter, setter)
x.gradient += z.gradient

return node_base(_forward, _backward, (x,))

@@ -167,17 +208,20 @@ def __add__(x, y):
if not isinstance(y, node_base):
y = node_base(y, with_gradient=False)

assert x._container == y._container
_container = x._container

def _forward():
return x.value + y.value

# not allowed to capture z, otherwise have reference loop!
def _backward(z):
if x.with_gradient:
accumulate_gradient(x, z.gradient)
x.gradient += z.gradient
if y.with_gradient:
accumulate_gradient(y, z.gradient)
y.gradient += z.gradient

return node_base(_forward, _backward, (x, y))
return node_base(_forward, _backward, (x, y), _container=_container, _tag="+")

def __sub__(x, y):
if not isinstance(x, node_base):
@@ -186,17 +230,20 @@ def __sub__(x, y):
if not isinstance(y, node_base):
y = node_base(y, with_gradient=False)

assert x._container == y._container
_container = x._container

def _forward():
return x.value - y.value

# not allowed to capture z, otherwise have reference loop!
def _backward(z):
if x.with_gradient:
accumulate_gradient(x, z.gradient)
x.gradient += z.gradient
if y.with_gradient:
accumulate_gradient(y, -z.gradient)
y.gradient -= z.gradient

return node_base(_forward, _backward, (x, y))
return node_base(_forward, _backward, (x, y), _container=_container, _tag="-")

def __rsub__(x, y):
return node_base.__sub__(y, x)
@@ -219,7 +266,8 @@ def forward(self, nodes, eager=True, free=None):
m.value = None
fields_allocated -= 1
if eager and isinstance(n.value, g.expr):
n.value = g(n.value)
if not n.value.is_adj():
n.value = g(n.value)

if verbose_memory:
g.message(
@@ -230,16 +278,16 @@ def backward(self, nodes, first_gradient, initial_gradient):
fields_allocated = len(nodes) # .values
max_fields_allocated = fields_allocated
if initial_gradient is None:
if is_field(self.value):
if self._container.is_field():
raise Exception(
"Expression evaluates to a field. Gradient calculation is not unique."
)
if isinstance(self.value, complex) and abs(self.value.imag) > 1e-12 * abs(
self.value.real
):
raise Exception(
f"Expression does not evaluate to a real number ({self.value}). Gradient calculation is not unique."
)
#if isinstance(self._container[0], complex) and abs(self.value.imag) > 1e-12 * abs(
# self.value.real
#):
# raise Exception(
# f"Expression does not evaluate to a real number ({self.value}). Gradient calculation is not unique."
# )
initial_gradient = 1.0
self.zero_gradient()
self.gradient += initial_gradient
5 changes: 2 additions & 3 deletions lib/gpt/ad/reverse/transform.py
@@ -18,7 +18,7 @@
#
import gpt as g
from gpt.ad.reverse import node_base
from gpt.ad.reverse.util import accumulate_gradient


def relu(x, a=0.0):
@@ -29,6 +28,6 @@ def _forward():
def _backward(z):
if x.with_gradient:
active = g.component.drelu(a)(x.value)
accumulate_gradient(x, g.component.multiply(active, z.gradient))
x.gradient += g.component.multiply(active, z.gradient)

return node_base(_forward, _backward, (x,))
return node_base(_forward, _backward, (x,), _container=x._container)
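The backward pass multiplies the incoming gradient by the pointwise derivative drelu(a). A NumPy sketch assuming the leaky form relu_a(v) = v for v > 0 and a*v otherwise (the exact form used by g.component.relu and g.component.drelu is an assumption here):

import numpy as np

def relu(v, a=0.0):
    # assumed leaky form: identity for v > 0, slope a otherwise
    return np.where(v > 0, v, a * v)

def drelu(v, a=0.0):
    return np.where(v > 0, 1.0, a)

x = np.array([-2.0, -0.5, 0.5, 3.0])
dz = np.array([1.0, -2.0, 0.5, 4.0])   # arbitrary seed gradient
a = 0.1

# reverse rule used above: x.gradient += drelu(a)(x.value) * z.gradient, element-wise
dx = drelu(x, a) * dz

# finite-difference check (all sample points are away from the kink at 0)
eps = 1e-6
num = (relu(x + eps, a) - relu(x, a)) / eps * dz
assert np.allclose(num, dx, atol=1e-4)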