From 62395f0f73852bad6f806202703b5eb02166fa59 Mon Sep 17 00:00:00 2001 From: Christoph Lehner Date: Mon, 6 Nov 2023 19:12:14 +0100 Subject: [PATCH] adj --- lib/gpt/ad/forward/foundation.py | 8 ++++++++ lib/gpt/ad/reverse/foundation.py | 12 +++++++++++ lib/gpt/core/foundation/lattice.py | 29 +++++++++++++++++++++++++++ lib/gpt/core/foundation/tensor.py | 6 ++++++ lib/gpt/core/operator/unary.py | 32 +++++++++--------------------- tests/ad/ad.py | 2 +- 6 files changed, 65 insertions(+), 24 deletions(-) diff --git a/lib/gpt/ad/forward/foundation.py b/lib/gpt/ad/forward/foundation.py index 8a2ff37f..b9e8902a 100644 --- a/lib/gpt/ad/forward/foundation.py +++ b/lib/gpt/ad/forward/foundation.py @@ -38,3 +38,11 @@ def cshift(sx, mu, disp, none=None): def trace(sx, t): return sx.distribute1(lambda a: g.trace(a, t)) + + +def adj(sx): + return sx.distribute1(lambda a: g.adj(a)) + + +def sum(sx): + return sx.distribute1(lambda a: g.sum(a)) diff --git a/lib/gpt/ad/reverse/foundation.py b/lib/gpt/ad/reverse/foundation.py index e32fa6d3..a87c1509 100644 --- a/lib/gpt/ad/reverse/foundation.py +++ b/lib/gpt/ad/reverse/foundation.py @@ -53,6 +53,18 @@ def _backward(z): return g.ad.reverse.node_base(_forward, _backward, (x,)) +def adj(x): + def _forward(): + return g.adj(x.value) + + # not allowed to capture z, otherwise have reference loop! + def _backward(z): + if x.with_gradient: + accumulate_gradient(x, g.adj(z.gradient)) + + return g.ad.reverse.node_base(_forward, _backward, (x,)) + + def component_simple_map(operator, numpy_operator, extra_params, first, second): if operator == "relu": assert second is None diff --git a/lib/gpt/core/foundation/lattice.py b/lib/gpt/core/foundation/lattice.py index f9433504..0528699e 100644 --- a/lib/gpt/core/foundation/lattice.py +++ b/lib/gpt/core/foundation/lattice.py @@ -74,3 +74,32 @@ def component_simple_map(operator, numpy_operator, extra_params, first, second): for i in dst.otype.v_idx: cgpt.unary(dst.v_obj[i], src.v_obj[i], {**{"operator": operator}, **extra_params}) return dst + + +def adj(l): + return gpt.adj(gpt.expr(l)) + + +def rank_sum(l): + val = [cgpt.lattice_rank_sum(x) for x in l.v_obj] + vrank = len(val) + if vrank == 1: + val = val[0] + else: + vdim = len(l.otype.shape) + if vdim == 1: + val = numpy.concatenate(val) + elif vdim == 2: + n = int(vrank**0.5) + assert n * n == vrank + val = numpy.concatenate( + [numpy.concatenate([val[i * n + j] for j in range(n)], axis=0) for i in range(n)], + axis=1, + ) + else: + raise NotImplementedError() + return gpt.util.value_to_tensor(val, l.otype) + + +def sum(l): + return l.grid.globalsum(rank_sum(l)) diff --git a/lib/gpt/core/foundation/tensor.py b/lib/gpt/core/foundation/tensor.py index 485c4b79..6571d29b 100644 --- a/lib/gpt/core/foundation/tensor.py +++ b/lib/gpt/core/foundation/tensor.py @@ -46,3 +46,9 @@ def component_simple_map(operator, numpy_operator, extra_params, first, second): res = first.new() res.array = numpy_operator(first.array) return res + + +def adj(l): + if l.transposable(): + return l.adj() + return gpt.adj(gpt.expr(l)) diff --git a/lib/gpt/core/operator/unary.py b/lib/gpt/core/operator/unary.py index 7cb7f1b7..2af99fd6 100644 --- a/lib/gpt/core/operator/unary.py +++ b/lib/gpt/core/operator/unary.py @@ -72,8 +72,10 @@ def adj(l): for a in l.val ] ) - elif (isinstance(l, gpt.tensor) and l.transposable()) or isinstance(l, gpt.matrix_operator): + elif isinstance(l, gpt.matrix_operator): return l.adj() + elif isinstance(l, gpt.core.foundation.base): + return l.__class__.foundation.adj(l) else: return adj(gpt.expr(l)) @@ -111,27 +113,11 @@ def color_trace(l): def rank_sum(e): - l = gpt.eval(e) - val = [cgpt.lattice_rank_sum(x) for x in l.v_obj] - vrank = len(val) - if vrank == 1: - val = val[0] - else: - vdim = len(l.otype.shape) - if vdim == 1: - val = np.concatenate(val) - elif vdim == 2: - n = int(vrank**0.5) - assert n * n == vrank - val = np.concatenate( - [np.concatenate([val[i * n + j] for j in range(n)], axis=0) for i in range(n)], - axis=1, - ) - else: - raise NotImplementedError() - return gpt.util.value_to_tensor(val, l.otype) - + if isinstance(e, gpt.expr): + e = gpt.eval(e) + return e.__class__.foundation.rank_sum(e) def sum(e): - l = gpt.eval(e) - return l.grid.globalsum(rank_sum(l)) + if isinstance(e, gpt.expr): + e = gpt.eval(e) + return e.__class__.foundation.sum(e) diff --git a/tests/ad/ad.py b/tests/ad/ad.py index f33a25ba..556ba7cb 100755 --- a/tests/ad/ad.py +++ b/tests/ad/ad.py @@ -31,7 +31,7 @@ # test a few simple models for c, learn_rate in [ (g.norm2(a1) + 3.0 * g.norm2(a2 * b1 + b2 + t1 * x), 1e-1), - (g.norm2(relu(a2 * relu(a1 * x + b1) + t1 * x + b2) - x), 1e-1), + (g.norm2(relu(a2 * relu(a1 * x + b1) + g.adj(t1 * x + b2)) - x), 1e-1), ( g.norm2( 2.0 * a2 * t1 * a1 * relu(a1 * x + b1)