test_dynamic.mojo
from autograd import (
Tensor,
sin,
relu,
mse,
)
from autograd.utils.shape import shape
########################################################################################################################
# main function: tests the engine's dynamic computation graph (a model architecture that changes between iterations)
########################################################################################################################
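# The test below fits y = sin(15 * x) with a small MLP whose depth changes at
# runtime: an extra hidden layer (W_opt, b_opt) is used only during the first
# 99 epochs, so the graph must be rebuilt and cleared on every iteration.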
fn main() raises:
    # init params
    let W1 = Tensor(shape(1, 64)).randhe().requires_grad()
    let W2 = Tensor(shape(64, 64)).randhe().requires_grad()
    let W3 = Tensor(shape(64, 1)).randhe().requires_grad()
    let W_opt = Tensor(shape(64, 64)).randhe().requires_grad()
    let b1 = Tensor(shape(64)).randhe().requires_grad()
    let b2 = Tensor(shape(64)).randhe().requires_grad()
    let b3 = Tensor(shape(1)).randhe().requires_grad()
    let b_opt = Tensor(shape(64)).randhe().requires_grad()
    # training
    var avg_loss = Float32(0.0)
    let every = 1000
    let num_epochs = 20000
    for epoch in range(1, num_epochs + 1):
        # set input and true values
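        # note: a fresh input tensor is created inside the loop, so the graph is
        # rebuilt from scratch every epoch; the .dynamic() call below presumably
        # marks the tensor as a per-iteration node (it is freed at the loop's end)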
        let input = Tensor(shape(32, 1)).randu(0, 1).dynamic()
        let true_vals = sin(15.0 * input)
        # define model architecture
        var x = relu(input @ W1 + b1)
        x = relu(x @ W2 + b2)
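        # dynamic part: this extra hidden layer is only part of the computation
        # graph during the first 99 epochs, after which the architecture changes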
        if epoch < 100:
            x = relu(x @ W_opt + b_opt)
        x = x @ W3 + b3
        let loss = mse(x, true_vals)
        # print progress
        avg_loss += loss[0]
        if epoch % every == 0:
            print("Epoch:", epoch, " Avg Loss: ", avg_loss / every)
            avg_loss = 0.0
        # compute gradients and optimize
        loss.backward()
        loss.optimize(0.01, "sgd")
        # clear graph
        loss.clear()
        input.free()
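        # clearing the graph and freeing the dynamic input each epoch lets the
        # next iteration build a (possibly different) graph cleanly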