
Commit

towards complete AD framework
lehner committed Nov 23, 2023
1 parent 958bf45 commit 7ddfac6
Showing 22 changed files with 528 additions and 181 deletions.
2 changes: 1 addition & 1 deletion lib/gpt/ad/forward/__init__.py
@@ -18,5 +18,5 @@
#
from gpt.ad.forward.infinitesimal import infinitesimal
from gpt.ad.forward.landau import landau
from gpt.ad.forward.series import series
from gpt.ad.forward.series import series, make
import gpt.ad.forward.foundation
103 changes: 102 additions & 1 deletion lib/gpt/ad/forward/foundation.py
@@ -19,18 +19,33 @@
import gpt as g


python_sum = sum


def inner_product(sx, sy, use_accelerator):
assert len(sx) == 1 and len(sy) == 1
sx = sx[0]
sy = sy[0]
return {(0, 0): sx.distribute2(sy, lambda a, b: g.inner_product(a, b, use_accelerator))}


def rank_inner_product(sx, sy, use_accelerator):
assert len(sx) == 1 and len(sy) == 1
sx = sx[0]
sy = sy[0]
return {(0, 0): sx.distribute2(sy, lambda a, b: g.rank_inner_product(a, b, use_accelerator))}


def norm2(sx):
assert len(sx) == 1
return [inner_product(sx, sx, True)[0, 0]]


def object_rank_norm2(sx):
assert len(sx) == 1
return python_sum([g.object_rank_norm2(sx[0].terms[t]) for t in sx[0].terms])


def cshift(sx, mu, disp, none=None):
assert none is None
return sx.distribute1(lambda a: g.cshift(a, mu, disp))
@@ -49,9 +64,95 @@ def sum(sx):


def identity(sx):
idsx = g.identity(sx[1])
idsx = None
for t in sx.terms:
idsx = g.identity(sx[t])
break
assert idsx is not None
return g.ad.forward.series({g.ad.forward.infinitesimal({}): idsx}, sx.landau_O)


def infinitesimal_to_cartesian(src, dsrc):
return dsrc[1].otype.infinitesimal_to_cartesian(src, dsrc)


def group_inner_product(left, right):
return left.distribute2(right, lambda a, b: g.group.inner_product(a, b))


def copy(dst, src):
for i in range(len(dst)):
dst[i] @= src[i]


def convert(first, second):
if isinstance(second, g.ot_base) and first.otype.__name__ != second.__name__:
assert second.__name__ in first.otype.ctab
tmp = g.ad.forward.series(
{t: g.lattice(first.grid, second) for t in first.terms}, first.landau_O
)
first.otype.ctab[second.__name__](tmp, first)
tmp.otype = second
return tmp

raise Exception(f"Not yet implemented for {type(first)} x {type(second)}")


def matrix_det(sx):
def df(x, dx, maxn):
# det(sx + dsx) = det(sx(1 + sx^-1 dsx))
# = det(sx) det(1 + sx^-1 dsx)
# = det(sx) (1 + tr(sx^-1 dsx) + O(dsx^2))
# higher-order:
# det(A) = exp ln det(A) = exp tr ln A
# det(sx + dsx) = exp tr ln (sx + dsx)
# ln(sx + dsx) = ln(sx) + ln(1 + sx^-1 dsx) | correct under exp tr
# = ln(sx) + sx^-1 dsx - (1/2) sx^-1 dsx sx^-1 dsx + O(dsx^2)
# tr[...] = tr[ln(sx)] + tr[sx^-1 dsx] - 1/2 tr[sx^-1 dsx sx^-1 dsx] + ...
# exp tr[...] = det(sx) * exp(tr[sx^-1 dsx]) * exp(- 1/2 tr[sx^-1 dsx sx^-1 dsx]) * ...
# exp tr[...] = det(sx) * (1 + tr[sx^-1 dsx] + 1/2 * tr[sx^-1 dsx]^2) * (1 - 1/2 tr[sx^-1 dsx sx^-1 dsx])
# det(sx + dsx)= det(sx) * (1 + tr[sx^-1 dsx] + 1/2 * tr[sx^-1 dsx]^2) * (1 - 1/2 tr[sx^-1 dsx sx^-1 dsx])
v0 = g.ad.forward.series(g.matrix.det(x), dx.landau_O)
if maxn >= 0:
v = v0
if maxn >= 2:
adjx = dx * g.matrix.inv(x)
tr_adjx = g.trace(adjx)
v += v0 * tr_adjx
if maxn >= 3:
adjx2 = adjx * adjx
tr_adjx2 = g.trace(adjx2)
v += v0 * (tr_adjx * tr_adjx - tr_adjx2) / 2.0
if maxn >= 4:
raise Exception(f"{maxn-1}-derivative of g.matrix.det not yet implemented")
v.otype = v0.otype
return v

return sx.function(df)
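
A standalone numerical check (not part of this commit) of the second-order expansion used in matrix_det above, with M = A^-1 B playing the role of sx^-1 dsx; numpy is used only for illustration.

import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4)) + 1j * rng.normal(size=(4, 4))
B = rng.normal(size=(4, 4)) + 1j * rng.normal(size=(4, 4))
eps = 1e-4

M = np.linalg.inv(A) @ B
trM, trM2 = np.trace(M), np.trace(M @ M)

exact = np.linalg.det(A + eps * B)
order2 = np.linalg.det(A) * (1 + eps * trM + 0.5 * eps**2 * (trM**2 - trM2))

print(abs(exact - order2) / abs(exact))  # relative error is O(eps**3)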


def component_simple_map(operator, numpy_operator, extra_params, first, second):
if operator == "pow":
assert second is None
exponent = extra_params["exponent"]
pow = g.component.pow

def df(x, dx, maxn):
# (x + dx)**exponent = x**exponent + exponent * x**(exponent-1) dx
# + 1/2 * exponent * (exponent-1) * x**(exponent-2) dx**2
v = g.ad.forward.series(pow(exponent)(x), dx.landau_O)
fac = None
for i in range(1, maxn):
fac = ((exponent + 1 - i) / i) * (g.component.multiply(dx, fac) if i > 1 else dx)
v += g.component.multiply(
fac, g.ad.forward.series(pow(exponent - i)(x), dx.landau_O)
)
v.otype = x.otype
return v

return first.function(df)
raise Exception(f"component-wise operator {operator} not implemented in forward-AD")


def component_multiply(sx, sy):
return sx.distribute2(sy, lambda a, b: g.component.multiply(a, b))
55 changes: 44 additions & 11 deletions lib/gpt/ad/forward/series.py
@@ -24,7 +24,7 @@

def promote(other, landau_O):
if isinstance(other, infinitesimal):
other = series({other: 1}, landau_O)
return series({other: 1}, landau_O)
elif isinstance(other, series):
return other

@@ -41,6 +41,10 @@ def __init__(self, terms, landau_O):
terms = {i0: terms}
self.terms = terms

def new(self):
terms = {t1: self.terms[t1].new() for t1 in self.terms}
return series(terms, self.landau_O)

def __str__(self):
r = ""
for t in self.terms:
@@ -94,15 +98,7 @@ def function(self, functional):
tn = tn * t
n += 1
maxn = max([maxn, n])
res = series({i0: functional(root, 0)}, self.landau_O)
delta = nilpotent
nfac = 1.0
for i in range(1, maxn):
nfac *= i
res += delta * functional(root, i) / nfac
if i != maxn - 1:
delta = delta * nilpotent
return res
return functional(root, nilpotent, maxn)
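
With this change, the functional passed to series.function receives (root, nilpotent, maxn) and assembles the truncated series itself, as matrix_det and the "pow" branch in foundation.py now do. A hedged sketch (not part of this commit) of a functional that reproduces the removed generic Taylor assembly, where derivative(root, i) is an illustrative stand-in for what the old functional(root, i) returned:

def taylor_functional(root, nilpotent, maxn):
    i0 = g.ad.forward.infinitesimal({})
    res = g.ad.forward.series({i0: derivative(root, 0)}, nilpotent.landau_O)
    delta = nilpotent
    nfac = 1.0
    for i in range(1, maxn):
        nfac *= i
        res += delta * derivative(root, i) / nfac
        if i != maxn - 1:
            delta = delta * nilpotent
    return res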

def __iadd__(self, other):
other = promote(other, self.landau_O)
@@ -114,6 +110,12 @@ def __iadd__(self, other):
def __mul__(self, other):
return self.distribute2(other, lambda a, b: a * b)

def __imul__(self, other):
res = self * other
self.landau_O = res.landau_O
self.terms = res.terms
return self

def __rmul__(self, other):
if g.util.is_num(other):
return self.__mul__(other)
@@ -182,6 +184,37 @@ def __setitem__(self, tag, value):
self.terms[tag] = value

def get_grid(self):
return self.terms[infinitesimal({})].grid
for t1 in self.terms:
return self.terms[t1].grid

def get_otype(self):
for t1 in self.terms:
return self.terms[t1].otype

def set_otype(self, otype):
for t1 in self.terms:
self.terms[t1].otype = otype

def __imatmul__(self, other):
assert self.landau_O is other.landau_O
terms = {}
for t1 in other.terms:
terms[t1] = g.copy(other.terms[t1])
self.terms = terms
return self

def get_real(self):
return self.distribute1(lambda a: a.real)

grid = property(get_grid)
real = property(get_real)
otype = property(get_otype, set_otype)


def make(landau_O, O1, *args):
x = series(O1, landau_O)
n = len(args)
assert n % 2 == 0
for i in range(n // 2):
x[args[2 * i + 0]] = args[2 * i + 1]
return x
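
A hedged usage sketch of the new make helper (not part of this commit); eps stands for an existing infinitesimal, lO for a Landau cutoff, and U0, dU for fields of matching grid and otype -- all illustrative names:

U = g.ad.forward.make(lO, U0, eps, dU)
# shorthand for
U = g.ad.forward.series(U0, lO)  # zeroth-order term
U[eps] = dU                      # coefficient of eps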
18 changes: 11 additions & 7 deletions lib/gpt/ad/reverse/foundation.py
@@ -20,23 +20,27 @@
from gpt.ad.reverse.util import accumulate_gradient


def inner_product(x, y):
def inner_product(x, y, use_accelerator):
assert len(x) == 1 and len(y) == 1
x = x[0]
y = y[0]

def _forward():
return g.inner_product(x.value, y.value)
return g.inner_product(x.value, y.value, use_accelerator)

# not allowed to capture z, otherwise have reference loop!
def _backward(z):
if x.with_gradient:
if x.with_gradient: # z = adj(x) y -> x z = y -> x = y adj(z)
accumulate_gradient(x, y.value * g.adj(z.gradient))
if y.with_gradient:
accumulate_gradient(y, x.value * g.adj(z.gradient))
if y.with_gradient: # z = adj(x) y -> y = x z
accumulate_gradient(y, x.value * z.gradient)

return g.ad.reverse.node_base(_forward, _backward, (x, y))
return {(0, 0): g.ad.reverse.node_base(_forward, _backward, (x, y))}
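
A standalone numerical check (not part of this commit) of the gradient convention in _backward above: for z = adj(x) y with a real objective f = Re(z) and seed dz = 1, the rules give grad_x = y * adj(dz) = y and grad_y = x * dz = x.

import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=3) + 1j * rng.normal(size=3)
y = rng.normal(size=3) + 1j * rng.normal(size=3)

def f(a, b):
    return np.vdot(a, b).real  # Re(adj(a) b)

eps, e = 1e-6, np.eye(3)
num_grad_x = np.array(
    [
        (f(x + eps * e[i], y) - f(x - eps * e[i], y)) / (2 * eps)
        + 1j * (f(x + 1j * eps * e[i], y) - f(x - 1j * eps * e[i], y)) / (2 * eps)
        for i in range(3)
    ]
)
print(np.max(np.abs(num_grad_x - y)))  # matches grad_x = y up to ~1e-10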


def norm2(x):
assert len(x) == 1
return [inner_product(x[0], x[0])]
return [g.inner_product(x, x)[0, 0]]


def cshift(x, direction, displacement, none):
76 changes: 66 additions & 10 deletions lib/gpt/ad/reverse/node.py
@@ -79,7 +79,15 @@ def gradient(self, fields, dfields):
class node_base(base):
foundation = foundation

def __init__(self, _forward, _backward=lambda z: None, _children=(), with_gradient=True):
# TODO: deprecate infinitesimal_to_cartesian and make it default
def __init__(
self,
_forward,
_backward=lambda z: None,
_children=(),
with_gradient=True,
infinitesimal_to_cartesian=True,
):
# global gctr
# gctr+=1
if not callable(_forward):
@@ -91,6 +99,7 @@ def __init__(self, _forward, _backward=lambda z: None, _children=(), with_gradie
self._backward = _backward
self._children = _children
self.with_gradient = with_gradient
self.infinitesimal_to_cartesian = infinitesimal_to_cartesian
self.gradient = None

# def __del__(self):
@@ -128,6 +137,29 @@ def __rmul__(x, y):
def __truediv__(self, other):
return (1.0 / other) * self

def __neg__(self):
return (-1.0) * self

def __getitem__(x, item):
def getter(y):
return y[item]

def setter(y, z):
y[item] = z

return x.project(getter, setter)

def project(x, getter, setter):
def _forward():
return getter(x.value)

# not allowed to capture z, otherwise have reference loop!
def _backward(z):
if x.with_gradient:
accumulate_gradient(x, z.gradient, getter, setter)

return node_base(_forward, _backward, (x,))
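
A standalone sketch (not part of this commit) of the getter/setter gradient routing that project implements: the derivative of f(x[item]) with respect to x is zero everywhere except in the slot selected by the getter, into which the setter writes the upstream gradient.

import numpy as np

x = np.array([1.0, 2.0, 3.0])
item = 1

def getter(y):
    return y[item]

def setter(y, z):
    y[item] = z

s = getter(x)            # forward: s = x[item]
grad_x = np.zeros_like(x)
setter(grad_x, 2.0 * s)  # backward for f(s) = s**2: df/ds = 2*s routed into slot "item"
print(grad_x)            # [0. 4. 0.]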

def __add__(x, y):
if not isinstance(x, node_base):
x = node_base(x, with_gradient=False)
@@ -194,13 +226,23 @@ def forward(self, nodes, eager=True, free=None):
f"Forward propagation through graph with {len(nodes)} nodes with maximum allocated fields: {max_fields_allocated}"
)

def backward(self, nodes, first_gradient):
def backward(self, nodes, first_gradient, initial_gradient):
fields_allocated = len(nodes) # .values
max_fields_allocated = fields_allocated
if is_field(self.value):
raise Exception("Expression evaluates to a field. Gradient calculation is not unique.")
if initial_gradient is None:
if is_field(self.value):
raise Exception(
"Expression evaluates to a field. Gradient calculation is not unique."
)
if isinstance(self.value, complex) and abs(self.value.imag) > 1e-12 * abs(
self.value.real
):
raise Exception(
f"Expression does not evaluate to a real number ({self.value}). Gradient calculation is not unique."
)
initial_gradient = 1.0
self.zero_gradient()
self.gradient += 1.0
self.gradient += initial_gradient
for n in reversed(nodes):
first_gradient_n = first_gradient[n]
for m in first_gradient_n:
@@ -216,19 +258,21 @@ def backward(self, nodes, first_gradient):
n.value = None
fields_allocated -= 1
else:
n.gradient = g.infinitesimal_to_cartesian(n.value, n.gradient)
if n.with_gradient and n.infinitesimal_to_cartesian:
n.gradient = g.infinitesimal_to_cartesian(n.value, n.gradient)

if verbose_memory:
g.message(
f"Backward propagation through graph with {len(nodes)} nodes with maximum allocated fields: {max_fields_allocated}"
)

def __call__(self, with_gradients=True):
# TODO: allow for lists of initial_gradients (could save forward runs at sake of more memory)
def __call__(self, with_gradients=True, initial_gradient=None):
nodes = []
forward_free = traverse(nodes, self)
self.forward(nodes, free=forward_free if not with_gradients else None)
if with_gradients:
self.backward(nodes, first_gradient=forward_free)
self.backward(nodes, first_gradient=forward_free, initial_gradient=initial_gradient)
nodes = None
return self.value

Expand All @@ -238,8 +282,20 @@ def functional(self, *arguments):
def get_grid(self):
return self.value.grid

def get_real(self):
def getter(y):
return y.real

def setter(y, z):
y @= z

return self.project(getter, setter)

grid = property(get_grid)
real = property(get_real)


def node(x, with_gradient=True):
return node_base(x, with_gradient=with_gradient)
def node(x, with_gradient=True, infinitesimal_to_cartesian=True):
return node_base(
x, with_gradient=with_gradient, infinitesimal_to_cartesian=infinitesimal_to_cartesian
)
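
A hedged usage sketch (not part of this commit) of the new keyword arguments; U is an illustrative lattice field and loss_of an illustrative function that builds a scalar node expression:

n = g.ad.reverse.node(U, with_gradient=True, infinitesimal_to_cartesian=False)
loss = loss_of(n)
value = loss(with_gradients=True, initial_gradient=1.0)  # 1.0 is also the default seed
grad = n.gradient  # stays in infinitesimal form since infinitesimal_to_cartesian=False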