Merge pull request #2 from JuliaPOMDP/enabling_tests
[WIP] started enabling tests
zsunberg authored Sep 13, 2018
2 parents 704fd30 + fd7f13a commit 0a1171f
Showing 10 changed files with 48 additions and 29 deletions.
src/stochastic.jl (2 changes: 1 addition & 1 deletion)

@@ -27,7 +27,7 @@ end
 CategoricalTabularPolicy(mdp::Union{POMDP,MDP}; rng=Random.GLOBAL_RNG) = CategoricalTabularPolicy(StochasticPolicy(Weights(zeros(n_actions(mdp)))), ValuePolicy(mdp))

 function action(policy::CategoricalTabularPolicy, s)
-    policy.stochastic.distribution = Weights(policy.value.value_table[state_index(policy.value.mdp, s),:])
+    policy.stochastic.distribution = Weights(policy.value.value_table[stateindex(policy.value.mdp, s),:])
     return policy.value.act[sample(policy.stochastic.rng, policy.stochastic.distribution)]
 end

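The substantive change here, repeated throughout this PR, is the rename of `state_index` to `stateindex` (and, in the tests, of `GridWorld` to `LegacyGridWorld`), apparently tracking renames in the upstream POMDPs.jl and POMDPModels packages. For orientation, a minimal sketch of what `stateindex` provides, using a hypothetical two-state MDP (the type and states below are illustrative, not from this repo):

    using POMDPs

    # Hypothetical problem type, purely for illustration:
    struct MyMDP <: MDP{Symbol, Symbol} end

    POMDPs.states(::MyMDP) = [:low, :high]
    # stateindex returns the 1-based position of s in the state space;
    # CategoricalTabularPolicy above uses it to select a row of the value table:
    POMDPs.stateindex(m::MyMDP, s::Symbol) = findfirst(isequal(s), states(m))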
src/vector.jl (18 changes: 9 additions & 9 deletions)

@@ -2,14 +2,14 @@
 # maintained by @zsunberg and @etotheipluspi

 """
-A generic MDP policy that consists of a vector of actions. The entry at `state_index(mdp, s)` is the action that will be taken in state `s`.
+A generic MDP policy that consists of a vector of actions. The entry at `stateindex(mdp, s)` is the action that will be taken in state `s`.
 """
 mutable struct VectorPolicy{S,A} <: Policy
     mdp::MDP{S,A}
     act::Vector{A}
 end

-action(p::VectorPolicy, s) = p.act[state_index(p.mdp, s)]
+action(p::VectorPolicy, s) = p.act[stateindex(p.mdp, s)]
 action(p::VectorPolicy, s, a) = action(p, s)

 """
@@ -25,19 +25,19 @@ end


 """
-A generic MDP policy that consists of a value table. The entry at `state_index(mdp, s)` is the action that will be taken in state `s`.
+A generic MDP policy that consists of a value table. The entry at `stateindex(mdp, s)` is the action that will be taken in state `s`.
 """
-mutable struct ValuePolicy{A} <: Policy
-    mdp::Union{MDP,POMDP}
-    value_table::Matrix{Float64}
+mutable struct ValuePolicy{P<:Union{POMDP,MDP}, T<:AbstractMatrix{Float64}, A} <: Policy
+    mdp::P
+    value_table::T
     act::Vector{A}
 end
-function ValuePolicy(mdp::Union{MDP,POMDP})
+function ValuePolicy(mdp::Union{MDP,POMDP}, value_table = zeros(n_states(mdp), n_actions(mdp)))
     acts = Any[]
     for a in actions(mdp)
         push!(acts, a)
     end
-    return ValuePolicy(mdp, zeros(n_states(mdp), n_actions(mdp)), acts)
+    return ValuePolicy(mdp, value_table, acts)
 end

-action(p::ValuePolicy, s) = p.act[argmax(p.value_table[state_index(p.mdp, s),:])]
+action(p::ValuePolicy, s) = p.act[argmax(p.value_table[stateindex(p.mdp, s),:])]
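Two things change in this file besides the rename. First, for orientation, a minimal VectorPolicy usage sketch (mirroring test/test_vector_policy.jl below; the vector length of 5 for a 2x2 grid is taken from that test and appears to count four cells plus a terminal state):

    using POMDPs, POMDPModels, POMDPPolicies

    gw = LegacyGridWorld(sx=2, sy=2)
    p = VectorPolicy(gw, fill(GridWorldAction(:left), 5))
    action(p, GridWorldState(1, 1))   # looks up act[stateindex(gw, s)]

Second, ValuePolicy is now parametric in both the problem type and the table type, so its fields are concretely typed rather than the abstract Union{MDP,POMDP} and Matrix{Float64}, and the constructor accepts a precomputed value table instead of always allocating zeros. A hedged sketch of the new constructor in use (the random table is illustrative only):

    using POMDPs, POMDPModels, POMDPPolicies

    gw = LegacyGridWorld(sx=2, sy=2)
    # supply a precomputed n_states(gw) x n_actions(gw) table:
    vt = rand(n_states(gw), n_actions(gw))
    p = ValuePolicy(gw, vt)
    action(p, GridWorldState(1, 1))   # argmax over this state's row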
test/REQUIRE (1 change: 1 addition & 0 deletions)

@@ -1 +1,2 @@
 POMDPModels
+POMDPSimulators
test/runtests.jl (30 changes: 17 additions & 13 deletions)

@@ -2,21 +2,25 @@ using Test
 using POMDPPolicies
 using POMDPs
 using BeliefUpdaters
-# using POMDPSimulators
+using POMDPSimulators
 using POMDPModels
 using Random

-# @testset "alpha" begin
-#     include("test_alpha_policy.jl")
-# end
+@testset "alpha" begin
+    include("test_alpha_policy.jl")
+end
 @testset "function" begin
     include("test_function_policy.jl")
 end
-# @testset "stochastic" begin
-#     include("test_stochastic_policy.jl")
-# end
-# @testset "utility" begin
-#     include("test_utility_wrapper.jl")
-# end
-# @testset "vector" begin
-#     include("test_vector_policy.jl")
-# end
+@testset "stochastic" begin
+    include("test_stochastic_policy.jl")
+end
+@testset "utility" begin
+    include("test_utility_wrapper.jl")
+end
+@testset "vector" begin
+    include("test_vector_policy.jl")
+end
+@testset "random" begin
+    include("test_random_solver.jl")
+end
test/test_alpha_policy.jl (5 changes: 3 additions & 2 deletions)

@@ -5,12 +5,13 @@ let
     b0 = initialize_belief(bu, initialstate_distribution(pomdp))

     # these values were gotten from FIB.jl
-    alphas = [-29.4557 -36.5093; -19.4557 -16.0629]
+    # alphas = [-29.4557 -36.5093; -19.4557 -16.0629]
+    alphas = [-16.0629 -19.4557; -36.5093 -29.4557]
     policy = AlphaVectorPolicy(pomdp, alphas)

     # initial belief is 100% confidence in baby not being hungry
     @test isapprox(value(policy, b0), -16.0629)
-    @test isapprox(value(policy, [0.0,1.0]), -16.0629)
+    @test isapprox(value(policy, [1.0,0.0]), -16.0629)

     # because baby isn't hungry, policy should not feed (return false)
     @test action(policy, b0) == false
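The reordered alpha matrix is consistent with AlphaVectorPolicy treating each column as one alpha vector indexed by state, the policy's value being the best dot product against the belief; that reading is inferred from the test values, not a statement of the package's internals. A sketch of that computation with the numbers above:

    using LinearAlgebra

    # assume each column of `alphas` is one alpha vector indexed by state:
    alpha_value(alphas, b) = maximum(dot(b, alphas[:, i]) for i in 1:size(alphas, 2))

    alphas = [-16.0629 -19.4557; -36.5093 -29.4557]
    alpha_value(alphas, [1.0, 0.0])   # ≈ -16.0629, matching the test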
test/test_function_policy.jl (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@ let
     @test action(p, true) == false

     s = FunctionSolver(x::Int->2*x)
-    p = solve(s, GridWorld())
+    p = solve(s, LegacyGridWorld())
     @test action(p, 10) == 20
     @test action(p, 100) == 200
     updater(p) # just to make sure this doesn't error
test/test_random_solver.jl (13 changes: 13 additions & 0 deletions)

@@ -0,0 +1,13 @@
+let
+    problem = BabyPOMDP()
+
+    solver = RandomSolver(rng=MersenneTwister(1))
+
+    policy = solve(solver, problem)
+
+    sim = RolloutSimulator(max_steps=10, rng=MersenneTwister(1))
+
+    r = simulate(sim, problem, policy, updater(policy), initial_state_distribution(problem))
+
+    @test isapprox(r, -27.27829, atol=1e-3)
+end
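For context, the rollout simulator is assumed here to return the cumulative discounted reward over at most max_steps steps; roughly, in illustrative form rather than the package's actual implementation:

    # discounted return over a finite trajectory of rewards:
    discounted_return(rewards, discount) =
        sum(discount^(t - 1) * r for (t, r) in enumerate(rewards))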
test/test_stochastic_policy.jl (2 changes: 1 addition & 1 deletion)

@@ -2,7 +2,7 @@ let

     using POMDPModels

-    problem = GridWorld()
+    problem = LegacyGridWorld()

     policy = UniformRandomPolicy(problem)
     sim = RolloutSimulator(max_steps=10)
test/test_utility_wrapper.jl (2 changes: 1 addition & 1 deletion)

@@ -1,7 +1,7 @@
 let
     using POMDPModels

-    mdp = GridWorld()
+    mdp = LegacyGridWorld()
     policy = RandomPolicy(mdp)
     counts = Dict(a=>0 for a in actions(mdp))

test/test_vector_policy.jl (2 changes: 1 addition & 1 deletion)

@@ -1,5 +1,5 @@
 let
-    gw = GridWorld(sx=2, sy=2, rs=[GridWorldState(1,1)], rv=[10.0])
+    gw = LegacyGridWorld(sx=2, sy=2, rs=[GridWorldState(1,1)], rv=[10.0])

     pvec = fill(GridWorldAction(:left), 5)

