From 9d8aad033dcca6c11e1ea8bf8582b95a4fa02d3c Mon Sep 17 00:00:00 2001 From: Christoph Lehner Date: Sun, 2 Jun 2024 21:49:57 +0200 Subject: [PATCH] add diff func --- .../core/group/differentiable_functional.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/lib/gpt/core/group/differentiable_functional.py b/lib/gpt/core/group/differentiable_functional.py index fca74b08..2247e9be 100644 --- a/lib/gpt/core/group/differentiable_functional.py +++ b/lib/gpt/core/group/differentiable_functional.py @@ -111,6 +111,26 @@ def assert_gradient_error(self, rng, fields, dfields, epsilon_approx, epsilon_as def transformed(self, t): return transformed(self, t) + def __add__(self, other): + return added(self, other) + + +class added(differentiable_functional): + def __init__(self, a, b): + self.a = a + self.b = b + + def __call__(self, fields): + a = self.a(fields) + b = self.b(fields) + # g.message("Action",a,b) + return a + b + + def gradient(self, fields, dfields): + a_grad = self.a.gradient(fields, dfields) + b_grad = self.b.gradient(fields, dfields) + return [g(x + y) for x, y in zip(a_grad, b_grad)] + class transformed(differentiable_functional): def __init__(self, f, t):