From f10811f6c54d33a28009621b69f902ac395467d9 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Wed, 12 Sep 2018 18:25:58 -0700 Subject: [PATCH 1/3] started enabling tests (WIP) --- test/REQUIRE | 1 + test/runtests.jl | 26 +++++++++++++------------- test/test_alpha_policy.jl | 5 +++-- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/test/REQUIRE b/test/REQUIRE index 8d06ae1..8fcd93b 100644 --- a/test/REQUIRE +++ b/test/REQUIRE @@ -1 +1,2 @@ POMDPModels +POMDPSimulators diff --git a/test/runtests.jl b/test/runtests.jl index cb8f203..b5761c0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,21 +2,21 @@ using Test using POMDPPolicies using POMDPs using BeliefUpdaters -# using POMDPSimulators +using POMDPSimulators using POMDPModels -# @testset "alpha" begin -# include("test_alpha_policy.jl") -# end +@testset "alpha" begin + include("test_alpha_policy.jl") +end @testset "function" begin include("test_function_policy.jl") end -# @testset "stochastic" begin -# include("test_stochastic_policy.jl") -# end -# @testset "utility" begin -# include("test_utility_wrapper.jl") -# end -# @testset "vector" begin -# include("test_vector_policy.jl") -# end +@testset "stochastic" begin + include("test_stochastic_policy.jl") +end +@testset "utility" begin + include("test_utility_wrapper.jl") +end +@testset "vector" begin + include("test_vector_policy.jl") +end diff --git a/test/test_alpha_policy.jl b/test/test_alpha_policy.jl index bb06b71..2941703 100644 --- a/test/test_alpha_policy.jl +++ b/test/test_alpha_policy.jl @@ -5,12 +5,13 @@ let b0 = initialize_belief(bu, initialstate_distribution(pomdp)) # these values were gotten from FIB.jl - alphas = [-29.4557 -36.5093; -19.4557 -16.0629] + # alphas = [-29.4557 -36.5093; -19.4557 -16.0629] + alphas = [ -16.0629 -19.4557; -36.5093 -29.4557] policy = AlphaVectorPolicy(pomdp, alphas) # initial belief is 100% confidence in baby not being hungry @test isapprox(value(policy, b0), -16.0629) - @test isapprox(value(policy, [0.0,1.0]), -16.0629) + @test isapprox(value(policy, [1.0,0.0]), -16.0629) # because baby isn't hungry, policy should not feed (return false) @test action(policy, b0) == false From 80d81773f410e579dbec229d89b48a40414c2cbc Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Thu, 13 Sep 2018 10:36:25 -0700 Subject: [PATCH 2/3] tests work --- src/stochastic.jl | 2 +- src/vector.jl | 18 +++++++++--------- test/test_function_policy.jl | 2 +- test/test_stochastic_policy.jl | 2 +- test/test_utility_wrapper.jl | 2 +- test/test_vector_policy.jl | 2 +- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/stochastic.jl b/src/stochastic.jl index 68c424a..42e4eb3 100644 --- a/src/stochastic.jl +++ b/src/stochastic.jl @@ -27,7 +27,7 @@ end CategoricalTabularPolicy(mdp::Union{POMDP,MDP}; rng=Random.GLOBAL_RNG) = CategoricalTabularPolicy(StochasticPolicy(Weights(zeros(n_actions(mdp)))), ValuePolicy(mdp)) function action(policy::CategoricalTabularPolicy, s) - policy.stochastic.distribution = Weights(policy.value.value_table[state_index(policy.value.mdp, s),:]) + policy.stochastic.distribution = Weights(policy.value.value_table[stateindex(policy.value.mdp, s),:]) return policy.value.act[sample(policy.stochastic.rng, policy.stochastic.distribution)] end diff --git a/src/vector.jl b/src/vector.jl index fa8566d..1cdaa92 100644 --- a/src/vector.jl +++ b/src/vector.jl @@ -2,14 +2,14 @@ # maintained by @zsunberg and @etotheipluspi """ -A generic MDP policy that consists of a vector of actions. The entry at `state_index(mdp, s)` is the action that will be taken in state `s`. +A generic MDP policy that consists of a vector of actions. The entry at `stateindex(mdp, s)` is the action that will be taken in state `s`. """ mutable struct VectorPolicy{S,A} <: Policy mdp::MDP{S,A} act::Vector{A} end -action(p::VectorPolicy, s) = p.act[state_index(p.mdp, s)] +action(p::VectorPolicy, s) = p.act[stateindex(p.mdp, s)] action(p::VectorPolicy, s, a) = action(p, s) """ @@ -25,19 +25,19 @@ end """ -A generic MDP policy that consists of a value table. The entry at `state_index(mdp, s)` is the action that will be taken in state `s`. +A generic MDP policy that consists of a value table. The entry at `stateindex(mdp, s)` is the action that will be taken in state `s`. """ -mutable struct ValuePolicy{A} <: Policy - mdp::Union{MDP,POMDP} - value_table::Matrix{Float64} +mutable struct ValuePolicy{P<:Union{POMDP,MDP}, T<:AbstractMatrix{Float64}, A} <: Policy + mdp::P + value_table::T act::Vector{A} end -function ValuePolicy(mdp::Union{MDP,POMDP}) +function ValuePolicy(mdp::Union{MDP,POMDP}, value_table = zeros(n_states(mdp), n_actions(mdp))) acts = Any[] for a in actions(mdp) push!(acts, a) end - return ValuePolicy(mdp, zeros(n_states(mdp), n_actions(mdp)), acts) + return ValuePolicy(mdp, value_table, acts) end -action(p::ValuePolicy, s) = p.act[argmax(p.value_table[state_index(p.mdp, s),:])] +action(p::ValuePolicy, s) = p.act[argmax(p.value_table[stateindex(p.mdp, s),:])] diff --git a/test/test_function_policy.jl b/test/test_function_policy.jl index e60ac74..1e143d1 100644 --- a/test/test_function_policy.jl +++ b/test/test_function_policy.jl @@ -3,7 +3,7 @@ let @test action(p, true) == false s = FunctionSolver(x::Int->2*x) - p = solve(s, GridWorld()) + p = solve(s, LegacyGridWorld()) @test action(p, 10) == 20 @test action(p, 100) == 200 updater(p) # just to make sure this doesn't error diff --git a/test/test_stochastic_policy.jl b/test/test_stochastic_policy.jl index 3656b7a..d372ef7 100644 --- a/test/test_stochastic_policy.jl +++ b/test/test_stochastic_policy.jl @@ -2,7 +2,7 @@ let using POMDPModels -problem = GridWorld() +problem = LegacyGridWorld() policy = UniformRandomPolicy(problem) sim = RolloutSimulator(max_steps=10) diff --git a/test/test_utility_wrapper.jl b/test/test_utility_wrapper.jl index ee7b3f3..cb9cc0b 100644 --- a/test/test_utility_wrapper.jl +++ b/test/test_utility_wrapper.jl @@ -1,7 +1,7 @@ let using POMDPModels - mdp = GridWorld() + mdp = LegacyGridWorld() policy = RandomPolicy(mdp) counts = Dict(a=>0 for a in actions(mdp)) diff --git a/test/test_vector_policy.jl b/test/test_vector_policy.jl index 7a9cb6d..33d9f64 100644 --- a/test/test_vector_policy.jl +++ b/test/test_vector_policy.jl @@ -1,5 +1,5 @@ let - gw = GridWorld(sx=2, sy=2, rs=[GridWorldState(1,1)], rv=[10.0]) + gw = LegacyGridWorld(sx=2, sy=2, rs=[GridWorldState(1,1)], rv=[10.0]) pvec = fill(GridWorldAction(:left), 5) From fd7f13a976439e1b21029ac6be8776d85b8ef399 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Thu, 13 Sep 2018 10:59:15 -0700 Subject: [PATCH 3/3] added random policy tests --- test/runtests.jl | 4 ++++ test/test_random_solver.jl | 13 +++++++++++++ 2 files changed, 17 insertions(+) create mode 100644 test/test_random_solver.jl diff --git a/test/runtests.jl b/test/runtests.jl index b5761c0..13276cb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using POMDPs using BeliefUpdaters using POMDPSimulators using POMDPModels +using Random @testset "alpha" begin include("test_alpha_policy.jl") @@ -20,3 +21,6 @@ end @testset "vector" begin include("test_vector_policy.jl") end +@testset "random" begin + include("test_random_solver.jl") +end diff --git a/test/test_random_solver.jl b/test/test_random_solver.jl new file mode 100644 index 0000000..b278cc1 --- /dev/null +++ b/test/test_random_solver.jl @@ -0,0 +1,13 @@ +let + problem = BabyPOMDP() + + solver = RandomSolver(rng=MersenneTwister(1)) + + policy = solve(solver, problem) + + sim = RolloutSimulator(max_steps=10, rng=MersenneTwister(1)) + + r = simulate(sim, problem, policy, updater(policy), initial_state_distribution(problem)) + + @test isapprox(r, -27.27829, atol=1e-3) +end