Skip to content

Commit

Permalink
Modify Experiment
Browse files: browse the repository at this point in the history
  • Loading branch information
CasBex committed Sep 25, 2023
1 parent 9294980 commit 66c37a7
Showing 1 changed file with 16 additions and 13 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -21,54 +21,57 @@ function RLCore.Experiment(
::Val{:CartPole},
seed = 123,
)

rng = StableRNG(seed)
env = CartPoleEnv(; T=Float32, rng=rng)
ns, na = length(state(env)), length(first(action_space(env)))
ns, na = length(state(env)), length(action_space(env))

agent = Agent(
policy=QBasedPolicy(
learner=NFQ(
action_space=action_space(env),
approximator=Approximator(
model=Chain(
Dense(ns+na, 5, σ; init=glorot_uniform(rng)),
Dense(5, 5, σ; init=glorot_uniform(rng)),
Dense(5, 1; init=glorot_uniform(rng)),
),
optimiser=RMSProp()
Dense(ns, 32, σ; init=glorot_uniform(rng)),
Dense(32, 32, relu; init=glorot_uniform(rng)),
Dense(32, na; init=glorot_uniform(rng)),
)|>gpu,
optimiser=RMSProp(),
),
loss_function=mse,
epochs=100,
epochs=10,
num_iterations=10,
γ = 0.95f0
),
explorer=EpsilonGreedyExplorer(
kind=:exp,
ϵ_stable=0.001,
warmup_steps=500,
warmup_steps=1000,
decay_steps=3000,
rng=rng,
),
),
trajectory=Trajectory(
container=CircularArraySARTSTraces(
capacity=10_000,
state=Float32 => (ns,),
action=Float32 => (na,),
),
sampler=BatchSampler{SS′ART}(
batch_size=128,
batch_size=2048,
rng=rng
),
controller=InsertSampleRatioController(
threshold=100,
n_inserted=-1
threshold=1000,
ratio=1/10,
n_sampled=-1
)
)
)

stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()

Experiment(agent, env, stop_condition, hook)

end

#+ tangle=false
Expand Down

0 comments on commit 66c37a7

Please sign in to comment.