From 66c37a7dfb79015ebb43ba78287396104c29ff33 Mon Sep 17 00:00:00 2001 From: Cas Bex Date: Mon, 25 Sep 2023 13:02:50 +0200 Subject: [PATCH] Modify Experiment --- .../experiments/DQN/JuliaRL_NFQ_CartPole.jl | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_NFQ_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_NFQ_CartPole.jl index d9690a031..de486114d 100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_NFQ_CartPole.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_NFQ_CartPole.jl @@ -21,31 +21,32 @@ function RLCore.Experiment( ::Val{:CartPole}, seed = 123, ) + rng = StableRNG(seed) env = CartPoleEnv(; T=Float32, rng=rng) - ns, na = length(state(env)), length(first(action_space(env))) + ns, na = length(state(env)), length(action_space(env)) agent = Agent( policy=QBasedPolicy( learner=NFQ( - action_space=action_space(env), approximator=Approximator( model=Chain( - Dense(ns+na, 5, σ; init=glorot_uniform(rng)), - Dense(5, 5, σ; init=glorot_uniform(rng)), - Dense(5, 1; init=glorot_uniform(rng)), - ), - optimiser=RMSProp() + Dense(ns, 32, σ; init=glorot_uniform(rng)), + Dense(32, 32, relu; init=glorot_uniform(rng)), + Dense(32, na; init=glorot_uniform(rng)), + )|>gpu, + optimiser=RMSProp(), ), loss_function=mse, - epochs=100, + epochs=10, num_iterations=10, γ = 0.95f0 ), explorer=EpsilonGreedyExplorer( kind=:exp, ϵ_stable=0.001, - warmup_steps=500, + warmup_steps=1000, + decay_steps=3000, rng=rng, ), ), @@ -53,22 +54,24 @@ function RLCore.Experiment( container=CircularArraySARTSTraces( capacity=10_000, state=Float32 => (ns,), - action=Float32 => (na,), ), sampler=BatchSampler{SS′ART}( - batch_size=128, + batch_size=2048, rng=rng ), controller=InsertSampleRatioController( - threshold=100, - n_inserted=-1 + threshold=1000, + ratio=1/10, + n_sampled=-1 ) ) ) stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) hook = TotalRewardPerEpisode() + Experiment(agent, env, stop_condition, hook) + end #+ tangle=false