Skip to content

Commit

Permalink
add diff func
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Jun 2, 2024
1 parent ff1ff5d commit 9d8aad0
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions lib/gpt/core/group/differentiable_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,26 @@ def assert_gradient_error(self, rng, fields, dfields, epsilon_approx, epsilon_as
def transformed(self, t):
return transformed(self, t)

def __add__(self, other):
return added(self, other)


class added(differentiable_functional):
def __init__(self, a, b):
self.a = a
self.b = b

def __call__(self, fields):
a = self.a(fields)
b = self.b(fields)
# g.message("Action",a,b)
return a + b

def gradient(self, fields, dfields):
a_grad = self.a.gradient(fields, dfields)
b_grad = self.b.gradient(fields, dfields)
return [g(x + y) for x, y in zip(a_grad, b_grad)]


class transformed(differentiable_functional):
def __init__(self, f, t):
Expand Down

0 comments on commit 9d8aad0

Please sign in to comment.