From b613796e4b00414bb8015f2cc7872efee6a3f94a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 13 Aug 2024 22:58:52 -0700 Subject: [PATCH] perf: use the permuted formulation --- bench/lux.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bench/lux.jl b/bench/lux.jl index b514f1f..a4b4c50 100644 --- a/bench/lux.jl +++ b/bench/lux.jl @@ -11,9 +11,9 @@ train!(args...; kwargs...) = train!(MSELoss(), AutoZygote(), args...; kwargs...) function train!(loss, backend, model, ps, st, data; epochs=10) l1 = loss(model, ps, st, first(data)) - tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.01f0)) + tstate = Training.TrainState(model, ps, st, Adam(0.01f0)) for _ in 1:epochs, (x, y) in data - _, _, _, tstate = Lux.Experimental.single_train_step!(backend, loss, (x, y), tstate) + _, _, _, tstate = Training.single_train_step!(backend, loss, (x, y), tstate) end l2 = loss(model, ps, st, first(data)) @@ -25,14 +25,14 @@ end n_points = 128 batch_size = 64 -x = rand(Float32, 1, n_points, batch_size); -y = rand(Float32, 1, n_points, batch_size); +x = rand(Float32, n_points, 1, batch_size); +y = rand(Float32, n_points, 1, batch_size); data = [(x, y)]; t_fwd = zeros(5) t_train = zeros(5) for i in 1:5 chs = (1, 128, fill(64, i)..., 128, 1) - model = FourierNeuralOperator(gelu; chs=chs, modes=(16,)) + model = FourierNeuralOperator(gelu; chs, modes=(16,), permuted=Val(true)) ps, st = Lux.setup(rng, model) model(x, ps, st) # TTFX