Skip to content

Commit

Permalink
adj
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Nov 6, 2023
1 parent a3e30cc commit 62395f0
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 24 deletions.
8 changes: 8 additions & 0 deletions lib/gpt/ad/forward/foundation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,11 @@ def cshift(sx, mu, disp, none=None):

def trace(sx, t):
    """Forward-AD trace: apply g.trace with trace type *t* through sx.distribute1."""
    return sx.distribute1(lambda term: g.trace(term, t))


def adj(sx):
    """Forward-AD adjoint: apply g.adj through sx.distribute1."""
    return sx.distribute1(lambda term: g.adj(term))


def sum(sx):
    """Forward-AD sum: apply g.sum through sx.distribute1."""
    return sx.distribute1(lambda term: g.sum(term))
12 changes: 12 additions & 0 deletions lib/gpt/ad/reverse/foundation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ def _backward(z):
return g.ad.reverse.node_base(_forward, _backward, (x,))


def adj(x):
    """Reverse-AD adjoint node: forward pass is g.adj of the node value;
    backward pass accumulates g.adj of the incoming gradient onto x."""

    def _forward():
        return g.adj(x.value)

    # not allowed to capture z, otherwise have reference loop!
    def _backward(z):
        if x.with_gradient:
            # d/dx adj(x) propagates as the adjoint of the upstream gradient
            accumulate_gradient(x, g.adj(z.gradient))

    # node_base wires _forward/_backward into the tape with x as the parent
    return g.ad.reverse.node_base(_forward, _backward, (x,))


Check failure on line 67 in lib/gpt/ad/reverse/foundation.py

View workflow job for this annotation

GitHub Actions / lint

W293:blank line contains whitespace
def component_simple_map(operator, numpy_operator, extra_params, first, second):
if operator == "relu":
assert second is None
Expand Down
29 changes: 29 additions & 0 deletions lib/gpt/core/foundation/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,32 @@ def component_simple_map(operator, numpy_operator, extra_params, first, second):
for i in dst.otype.v_idx:
cgpt.unary(dst.v_obj[i], src.v_obj[i], {**{"operator": operator}, **extra_params})
return dst


def adj(l):
    """Adjoint of a lattice *l*, returned via the gpt expression layer."""
    expression = gpt.expr(l)
    return gpt.adj(expression)


def rank_sum(l):
    """Sum lattice *l* over the sites held by this rank only (no global
    reduction — see sum() below, which follows this with grid.globalsum).

    The per-object partial sums come from cgpt; when the otype is split
    across several v_obj, the fragments are reassembled into one numpy
    array before being wrapped as a tensor.
    """
    # one partial result per internal object the lattice is split into
    val = [cgpt.lattice_rank_sum(x) for x in l.v_obj]
    vrank = len(val)
    if vrank == 1:
        val = val[0]
    else:
        vdim = len(l.otype.shape)
        if vdim == 1:
            # vector otype: fragments concatenate along the single axis
            val = numpy.concatenate(val)
        elif vdim == 2:
            # matrix otype: vrank fragments form an n x n grid; n must be exact
            n = int(vrank**0.5)
            assert n * n == vrank
            # stitch columns (axis=0) within each row, then rows (axis=1)
            val = numpy.concatenate(
                [numpy.concatenate([val[i * n + j] for j in range(n)], axis=0) for i in range(n)],
                axis=1,
            )
        else:
            # higher-rank split otypes are not supported here
            raise NotImplementedError()
    return gpt.util.value_to_tensor(val, l.otype)


def sum(l):
    """Global sum of lattice *l*: rank-local sum, then grid.globalsum reduction."""
    local = rank_sum(l)
    return l.grid.globalsum(local)
6 changes: 6 additions & 0 deletions lib/gpt/core/foundation/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,9 @@ def component_simple_map(operator, numpy_operator, extra_params, first, second):
res = first.new()
res.array = numpy_operator(first.array)
return res


def adj(l):
    """Adjoint of tensor *l*: direct l.adj() when transposable, otherwise
    fall back to the gpt expression layer."""
    if not l.transposable():
        return gpt.adj(gpt.expr(l))
    return l.adj()
32 changes: 9 additions & 23 deletions lib/gpt/core/operator/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,10 @@ def adj(l):
for a in l.val
]
)
elif (isinstance(l, gpt.tensor) and l.transposable()) or isinstance(l, gpt.matrix_operator):
elif isinstance(l, gpt.matrix_operator):
return l.adj()
elif isinstance(l, gpt.core.foundation.base):
return l.__class__.foundation.adj(l)
else:
return adj(gpt.expr(l))

Expand Down Expand Up @@ -111,27 +113,11 @@ def color_trace(l):


def rank_sum(e):
l = gpt.eval(e)
val = [cgpt.lattice_rank_sum(x) for x in l.v_obj]
vrank = len(val)
if vrank == 1:
val = val[0]
else:
vdim = len(l.otype.shape)
if vdim == 1:
val = np.concatenate(val)
elif vdim == 2:
n = int(vrank**0.5)
assert n * n == vrank
val = np.concatenate(
[np.concatenate([val[i * n + j] for j in range(n)], axis=0) for i in range(n)],
axis=1,
)
else:
raise NotImplementedError()
return gpt.util.value_to_tensor(val, l.otype)

if isinstance(e, gpt.expr):
e = gpt.eval(e)
return e.__class__.foundation.rank_sum(e)

def sum(e):
l = gpt.eval(e)
return l.grid.globalsum(rank_sum(l))
if isinstance(e, gpt.expr):
e = gpt.eval(e)
return e.__class__.foundation.sum(e)
2 changes: 1 addition & 1 deletion tests/ad/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
# test a few simple models
for c, learn_rate in [
(g.norm2(a1) + 3.0 * g.norm2(a2 * b1 + b2 + t1 * x), 1e-1),
(g.norm2(relu(a2 * relu(a1 * x + b1) + t1 * x + b2) - x), 1e-1),
(g.norm2(relu(a2 * relu(a1 * x + b1) + g.adj(t1 * x + b2)) - x), 1e-1),
(
g.norm2(
2.0 * a2 * t1 * a1 * relu(a1 * x + b1)
Expand Down

0 comments on commit 62395f0

Please sign in to comment.