Skip to content

Commit

Permalink
Made tests more efficient
Browse files Browse the repository at this point in the history
  • Loading branch information
nic-barbara committed Aug 15, 2023
1 parent f5775d4 commit 0af7067
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions examples/GPU/test_speed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@ using RobustNeuralNetworks

rng = Xoshiro(42)

# TODO: Probably better to send params to GPU then build off that.
function test_ren_speed(device, construct, args...; nu=4, nx=5, nv=10, ny=4,
nl=tanh, batches=4, tmax=3, is_diff=false, T=Float32,
do_time=true)
function test_ren_gpu(device, construct, args...; nu=4, nx=5, nv=10, ny=4,
nl=tanh, batches=4, tmax=3, is_diff=false, T=Float32,
do_time=true)

# Build the ren
model = construct{T}(nu, nx, nv, ny, args...; nl, rng)
is_diff && (model = DiffREN(model) |> device)
model = construct{T}(nu, nx, nv, ny, args...; nl, rng) |> device
is_diff && (model = DiffREN(model))

# Create dummy data
us = [randn(rng, T, nu, batches) for _ in 1:tmax] |> device
Expand All @@ -27,7 +26,7 @@ function test_ren_speed(device, construct, args...; nu=4, nx=5, nv=10, ny=4,

# Dummy loss function
function loss(model, x, us, ys)
m = is_diff ? model : (REN(model) |> device)
m = is_diff ? model : REN(model)
J = 0
for t in 1:tmax
x, y = m(x, us[t])
Expand Down Expand Up @@ -89,5 +88,5 @@ function test_rens(device)
return nothing
end

test_rens(cpu)
# test_rens(cpu)
test_rens(gpu)

0 comments on commit 0af7067

Please sign in to comment.