diff --git a/bench/Project.toml b/bench/Project.toml index fb9708d..5b4ff71 100644 --- a/bench/Project.toml +++ b/bench/Project.toml @@ -1,5 +1,6 @@ [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/bench/lux.jl b/bench/lux.jl index b514f1f..b5959de 100644 --- a/bench/lux.jl +++ b/bench/lux.jl @@ -2,7 +2,7 @@ using ThreadPinning pinthreads(:cores) threadinfo() -using BenchmarkTools, NeuralOperators, Random, Optimisers, Zygote +using BenchmarkTools, NeuralOperators, Random, Optimisers, Zygote, Lux rng = Xoshiro(1234) @@ -11,9 +11,9 @@ train!(args...; kwargs...) = train!(MSELoss(), AutoZygote(), args...; kwargs...) function train!(loss, backend, model, ps, st, data; epochs=10) l1 = loss(model, ps, st, first(data)) - tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.01f0)) + tstate = Training.TrainState(model, ps, st, Adam(0.01f0)) for _ in 1:epochs, (x, y) in data - _, _, _, tstate = Lux.Experimental.single_train_step!(backend, loss, (x, y), tstate) + _, _, _, tstate = Training.single_train_step!(backend, loss, (x, y), tstate) end l2 = loss(model, ps, st, first(data)) @@ -25,14 +25,14 @@ end n_points = 128 batch_size = 64 -x = rand(Float32, 1, n_points, batch_size); -y = rand(Float32, 1, n_points, batch_size); +x = rand(Float32, n_points, 1, batch_size); +y = rand(Float32, n_points, 1, batch_size); data = [(x, y)]; t_fwd = zeros(5) t_train = zeros(5) for i in 1:5 chs = (1, 128, fill(64, i)..., 128, 1) - model = FourierNeuralOperator(gelu; chs=chs, modes=(16,)) + model = FourierNeuralOperator(gelu; chs, modes=(16,), permuted=Val(true)) ps, st = Lux.setup(rng, model) model(x, ps, st) # TTFX diff --git a/test/Project.toml b/test/Project.toml index 1ddcfa0..afbefaf 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,6 +5,8 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"