Commit c93b1a0: auto differentiable functional

lehner committed Oct 31, 2023
1 parent c9545ee

Showing 3 changed files with 27 additions and 22 deletions.
26 changes: 26 additions & 0 deletions lib/gpt/ad/reverse/node.py
@@ -50,6 +50,28 @@ def traverse(nodes, n, visited=None):
    return forward_free



class node_differentiable_functional(g.group.differentiable_functional):
    def __init__(self, node, arguments):
        self.node = node
        self.arguments = arguments

    def __call__(self, fields):
        # assign the given fields to the argument nodes, then run the
        # forward pass only and return the real part of the result
        assert len(fields) == len(self.arguments)
        for i in range(len(fields)):
            self.arguments[i].value @= fields[i]
        return self.node(with_gradients=False).real

    def gradient(self, fields, dfields):
        # accumulate gradients only for the requested subset of arguments
        for a in self.arguments:
            a.with_gradient = False
        indices = [fields.index(df) for df in dfields]
        for i in indices:
            self.arguments[i].gradient = None
            self.arguments[i].with_gradient = True
        self.node()
        return [self.arguments[i].gradient for i in indices]
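
# A minimal usage sketch (assumed, mirroring tests/ad/ad.py below): given
# argument nodes a1, b1 and a scalar-valued graph node c built from them,
#
#   f = c.functional(a1, b1)        # bind c as a functional of (a1, b1)
#   fields = [a1.value, b1.value]
#   v = f(fields)                   # forward evaluation only
#   (da1,) = f.gradient(fields, [fields[0]])  # gradient w.r.t. a1 alone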

# gctr = 0


@@ -179,6 +201,10 @@ def __call__(self, with_gradients=True):
        nodes = None
        return self.value

    def functional(self, *arguments):
        # expose this node as a differentiable functional of the given arguments
        return node_differentiable_functional(self, arguments)
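
# Sketch (assumed optimizer call convention) of feeding such a functional
# into gpt's optimizers, as the updated test below does:
#
#   f = c.functional(a1, b1, a2, b2, t1)
#   opt = g.algorithms.optimize.adam(maxiter=40, eps=1e-7, alpha=learn_rate)
#   opt(f)(ff, ff)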



def node(x, with_gradient=True):
    return node_base(x, with_gradient=with_gradient)
3 changes: 0 additions & 3 deletions lib/gpt/core/tensor.py
@@ -140,9 +140,6 @@ def __add__(self, other):
    def __truediv__(self, other):
        return tensor(self.array / other, self.otype)

    def __eq__(self, other):
        return np.array_equal(self.array, other.array)

    def __neg__(self):
        return tensor(-self.array, self.otype)
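
# Editorial note (assumption): with __eq__ removed, tensor comparison falls
# back to Python's default identity semantics, which the new
# node_differentiable_functional.gradient presumably relies on when locating
# arguments via fields.index(df), e.g.:
#
#   fields = [t_a, t_b]   # hypothetical tensors with equal arrays
#   fields.index(t_b)     # 1 under identity; could be 0 under array equality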

20 changes: 1 addition & 19 deletions tests/ad/ad.py
@@ -60,25 +60,7 @@
assert err < 1e-4

# create something to minimize
class fnc(g.group.differentiable_functional):
    def __init__(self):
        pass

    def __call__(self, fields):
        global a1, b1, a2, b2, t1
        a1.value @= fields[0]
        b1.value @= fields[1]
        a2.value @= fields[2]
        b2.value @= fields[3]
        t1.value @= fields[4]
        return c(with_gradients=False).real

    def gradient(self, fields, dfields):
        c()
        assert dfields == fields
        return [a1.gradient, b1.gradient, a2.gradient, b2.gradient, t1.gradient]

f = fnc()
f = c.functional(a1, b1, a2, b2, t1)
ff = [a1.value, b1.value, a2.value, b2.value, t1.value]
v0 = f(ff)
opt = g.algorithms.optimize.adam(maxiter=40, eps=1e-7, alpha=learn_rate)
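
# A quick consistency sketch (assumed) against the interface of the removed
# fnc class above:
#
#   v0 = f(ff)                   # forward value
#   grads = f.gradient(ff, ff)   # gradients w.r.t. all five fields
#   assert len(grads) == len(ff)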