diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/prioritized_dqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/prioritized_dqn.jl index cb0ac9600..e1de1d169 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/prioritized_dqn.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/prioritized_dqn.jl @@ -42,6 +42,8 @@ function RLBase.optimise!( a = CartesianIndex.(a, 1:batch_size) k, p = batch.key, batch.priority p′ = similar(p) + s, s′, a, r, t = send_to_device(device(Q), (s, s′, a, r, t)) + k, p, p′ = send_to_device(device(Q), (k, p, p′)) w = 1.0f0 ./ ((p .+ 1.0f-10) .^ β) w ./= maximum(w)