From cd6d2c63ca81f9f9b1a1ffb204580ad2ec6cdd56 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Fri, 22 Jul 2022 09:50:16 -0700 Subject: [PATCH] switched to using POMDPTools --- Project.toml | 12 ++-- README.md | 15 +---- src/POMDPPolicies.jl | 84 ++--------------------- src/alpha_vector.jl | 124 ---------------------------------- src/exploration_policies.jl | 128 ------------------------------------ src/function.jl | 27 -------- src/playback.jl | 45 ------------- src/pretty_printing.jl | 64 ------------------ src/random.jl | 48 -------------- src/stochastic.jl | 59 ----------------- src/utility_wrapper.jl | 83 ----------------------- src/vector.jl | 72 -------------------- 12 files changed, 16 insertions(+), 745 deletions(-) delete mode 100644 src/alpha_vector.jl delete mode 100644 src/exploration_policies.jl delete mode 100644 src/function.jl delete mode 100644 src/playback.jl delete mode 100644 src/pretty_printing.jl delete mode 100644 src/random.jl delete mode 100644 src/stochastic.jl delete mode 100644 src/utility_wrapper.jl delete mode 100644 src/vector.jl diff --git a/Project.toml b/Project.toml index 2330c97..501dc83 100644 --- a/Project.toml +++ b/Project.toml @@ -1,15 +1,15 @@ name = "POMDPPolicies" uuid = "182e52fb-cfd0-5e46-8c26-fd0667c990f4" -version = "0.4.2" +version = "0.4.3" [deps] -BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -POMDPModelTools = "08074719-1b2a-587c-a292-00f91cc44415" +POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7" POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -17,16 +17,20 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" BeliefUpdaters = "0.1, 0.2" Distributions = "0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25" POMDPModelTools = "0.2, 0.3" +POMDPTools = "0.1" POMDPs = "0.7.3, 0.8, 0.9" Parameters = "0.12" +Reexport = "1" StatsBase = "0.26,0.27,0.28,0.29,0.30,0.31,0.32, 0.33" julia = "1" [extras] +BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4" +POMDPModelTools = "08074719-1b2a-587c-a292-00f91cc44415" POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca" POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Random", "POMDPSimulators", "POMDPModels"] +test = ["Test", "Random", "POMDPSimulators", "POMDPModels", "POMDPModelTools", "BeliefUpdaters"] diff --git a/README.md b/README.md index 8e73174..dd21ec8 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,3 @@ -# POMDPPolicies +# ~~POMDPPolicies~~ -[![Build Status](https://travis-ci.org/JuliaPOMDP/POMDPPolicies.jl.svg?branch=master)](https://travis-ci.org/JuliaPOMDP/POMDPPolicies.jl) -[![Coverage Status](https://coveralls.io/repos/github/JuliaPOMDP/POMDPPolicies.jl/badge.svg?branch=master)](https://coveralls.io/github/JuliaPOMDP/POMDPPolicies.jl?branch=master) -[![](https://img.shields.io/badge/docs-latest-blue.svg)](https://JuliaPOMDP.github.io/POMDPPolicies.jl/latest) - -A collection of default policy types for [POMDPs.jl](https://github.com/JuliaPOMDP/POMDPs.jl). 
- -# Installation - -```julia -using Pkg -Pkg.add("POMDPPolicies") -``` +POMDPPolicies is deprecated and the functionality has been moved to [POMDPTools](https://github.com/JuliaPOMDP/POMDPs.jl/tree/master/lib/POMDPTools). Please use that package instead. diff --git a/src/POMDPPolicies.jl b/src/POMDPPolicies.jl index 95100cf..0c44c57 100644 --- a/src/POMDPPolicies.jl +++ b/src/POMDPPolicies.jl @@ -1,85 +1,13 @@ module POMDPPolicies -using LinearAlgebra -using Random -using StatsBase # for Weights -using SparseArrays # for sparse vectors in alpha_vector.jl -using Parameters -using Distributions # For logpdf extenstion in playback policy +Base.depwarn(""" + The functionality of POMDPPolicies has been moved to POMDPTools. -using POMDPs -import POMDPs: action, value, solve, updater + Please replace `using POMDPPolicies` with `using POMDPTools`. + """, :POMDPPolicies) -using BeliefUpdaters -using POMDPModelTools +using Reexport -using Base.Iterators # for take - -""" - actionvalues(p::Policy, s) - -returns the values of each action at state s in a vector -""" -function actionvalues end - -export - actionvalues - -export - AlphaVectorPolicy, - alphavectors, - alphapairs - -include("alpha_vector.jl") - -export - FunctionPolicy, - FunctionSolver - -include("function.jl") - -export - RandomPolicy, - RandomSolver - -include("random.jl") - -export - VectorPolicy, - VectorSolver, - ValuePolicy - -include("vector.jl") - -export - StochasticPolicy, - UniformRandomPolicy, - CategoricalTabularPolicy - -include("stochastic.jl") - -export LinearDecaySchedule, - EpsGreedyPolicy, - SoftmaxPolicy, - ExplorationPolicy, - loginfo - -include("exploration_policies.jl") - -export - PolicyWrapper, - payload - -include("utility_wrapper.jl") - -export - showpolicy - -include("pretty_printing.jl") - -export - PlaybackPolicy - -include("playback.jl") +@reexport using POMDPTools.Policies end diff --git a/src/alpha_vector.jl b/src/alpha_vector.jl deleted file mode 100644 index 15ce623..0000000 --- a/src/alpha_vector.jl +++ /dev/null @@ -1,124 +0,0 @@ -""" - AlphaVectorPolicy(pomdp::POMDP, alphas, action_map) - -Construct a policy from alpha vectors. - -# Arguments -- `alphas`: an |S| x (number of alpha vecs) matrix or a vector of alpha vectors. -- `action_map`: a vector of the actions correponding to each alpha vector - - AlphaVectorPolicy{P<:POMDP, A} - -Represents a policy with a set of alpha vectors. 
- -Use `action` to get the best action for a belief, and `alphavectors` and `alphapairs` to - -# Fields -- `pomdp::P` the POMDP problem -- `n_states::Int` the number of states in the POMDP -- `alphas::Vector{Vector{Float64}}` the list of alpha vectors -- `action_map::Vector{A}` a list of action corresponding to the alpha vectors -""" -struct AlphaVectorPolicy{P<:POMDP, A} <: Policy - pomdp::P # needed for mapping states to locations in alpha vectors - n_states::Int - alphas::Vector{Vector{Float64}} - action_map::Vector{A} -end - -@deprecate AlphaVectorPolicy(pomdp::POMDP, alphas) AlphaVectorPolicy(pomdp, alphas, ordered_actions(pomdp)) - -function AlphaVectorPolicy(m::POMDP, alphas::AbstractVector, amap) - AlphaVectorPolicy(m, length(states(m)), alphas, - convert(Vector{actiontype(m)}, amap)) -end - -# assumes alphas is |S| x (number of alpha vecs) -function AlphaVectorPolicy(p::POMDP, alphas::Matrix{Float64}, action_map) - # turn alphas into vector of vectors - num_actions = size(alphas, 2) - alpha_vecs = Vector{Float64}[] - for i = 1:num_actions - push!(alpha_vecs, vec(alphas[:,i])) - end - - AlphaVectorPolicy(p, length(states(p)), alpha_vecs, - convert(Vector{actiontype(p)}, action_map)) -end - -updater(p::AlphaVectorPolicy) = DiscreteUpdater(p.pomdp) - -""" -Return an iterator of alpha vector-action pairs in the policy. -""" -alphapairs(p::AlphaVectorPolicy) = (p.alphas[i]=>p.action_map[i] for i in 1:length(p.alphas)) - -""" -Return the alpha vectors. -""" -alphavectors(p::AlphaVectorPolicy) = p.alphas - -# The three functions below rely on beliefvec being implemented for the belief type -# Implementations of beliefvec are below -function value(p::AlphaVectorPolicy, b) - bvec = beliefvec(p.pomdp, p.n_states, b) - maximum(dot(bvec,a) for a in p.alphas) -end - -function action(p::AlphaVectorPolicy, b) - bvec = beliefvec(p.pomdp, p.n_states, b) - num_vectors = length(p.alphas) - best_idx = 1 - max_value = -Inf - for i = 1:num_vectors - temp_value = dot(bvec, p.alphas[i]) - if temp_value > max_value - max_value = temp_value - best_idx = i - end - end - return p.action_map[best_idx] -end - -function actionvalues(p::AlphaVectorPolicy, b) - bvec = beliefvec(p.pomdp, p.n_states, b) - num_vectors = length(p.alphas) - max_values = -Inf*ones(length(actions(p.pomdp))) - for i = 1:num_vectors - temp_value = dot(bvec, p.alphas[i]) - ai = actionindex(p.pomdp, p.action_map[i]) - if temp_value > max_values[ai] - max_values[ai] = temp_value - end - end - return max_values -end - -""" - POMDPPolicies.beliefvec(m::POMDP, n_states::Int, b) - -Return a vector-like representation of the belief `b` suitable for calculating the dot product with the alpha vectors. 
-""" -function beliefvec end - -function beliefvec(m::POMDP, n, b::SparseCat) - return sparsevec(collect(stateindex(m, s) for s in b.vals), collect(b.probs), n) -end -beliefvec(m::POMDP, n, b::DiscreteBelief) = b.b -beliefvec(m::POMDP, n, b::AbstractArray) = b - -function beliefvec(m::POMDP, n_states, b) - sup = support(b) - bvec = zeros(n_states) - for s in sup - bvec[stateindex(m, s)] = pdf(b, s) - end - return bvec -end - -function Base.push!(p::AlphaVectorPolicy, alpha::Vector{Float64}, a) - push!(p.alphas, alpha) - push!(p.action_map, a) -end - -@deprecate beliefvec(m::POMDP, b) beliefvec(m, length(states(m)), b) diff --git a/src/exploration_policies.jl b/src/exploration_policies.jl deleted file mode 100644 index 387ca74..0000000 --- a/src/exploration_policies.jl +++ /dev/null @@ -1,128 +0,0 @@ -""" - LinearDecaySchedule -A schedule that linearly decreases a value from `start` to `stop` in `steps` steps. -if the value is greater or equal to `stop`, it stays constant. - -# Constructor - -`LinearDecaySchedule(;start, stop, steps)` -""" -@with_kw struct LinearDecaySchedule{R<:Real} <: Function - start::R - stop::R - steps::Int -end - -function (schedule::LinearDecaySchedule)(k) - rate = (schedule.start - schedule.stop) / schedule.steps - val = schedule.start - k*rate - val = max(schedule.stop, val) -end - - -""" - ExplorationPolicy <: Policy -An abstract type for exploration policies. -Sampling from an exploration policy is done using `action(exploration_policy, on_policy, k, state)`. -`k` is a value that is used to determine the exploration parameter. It is usually a training step in a TD-learning algorithm. -""" -abstract type ExplorationPolicy <: Policy end - -""" - loginfo(::ExplorationPolicy, k) -returns information about an exploration policy, e.g. epsilon for e-greedy or temperature for softmax. -It is expected to return a namedtuple (e.g. (temperature=0.5)). `k` is the current training step that is used to compute the exploration parameter. -""" -function loginfo end - -""" - EpsGreedyPolicy <: ExplorationPolicy - -represents an epsilon greedy policy, sampling a random action with a probability `eps` or returning an action from a given policy otherwise. -The evolution of epsilon can be controlled using a schedule. This feature is useful for using those policies in reinforcement learning algorithms. - -# Constructor: - -`EpsGreedyPolicy(problem::Union{MDP, POMDP}, eps::Union{Function, Float64}; rng=Random.GLOBAL_RNG, schedule=ConstantSchedule)` - -If a function is passed for `eps`, `eps(k)` is called to compute the value of epsilon when calling `action(exploration_policy, on_policy, k, s)`. - - -# Fields - -- `eps::Function` -- `rng::AbstractRNG` -- `actions::A` an indexable list of action -""" -struct EpsGreedyPolicy{T<:Function, R<:AbstractRNG, A} <: ExplorationPolicy - eps::T - rng::R - actions::A -end - -function EpsGreedyPolicy(problem, eps::Function; - rng::AbstractRNG=Random.GLOBAL_RNG) - return EpsGreedyPolicy(eps, rng, actions(problem)) -end -function EpsGreedyPolicy(problem, eps::Real; - rng::AbstractRNG=Random.GLOBAL_RNG) - return EpsGreedyPolicy(x->eps, rng, actions(problem)) -end - - -function POMDPs.action(p::EpsGreedyPolicy, on_policy::Policy, k, s) - if rand(p.rng) < p.eps(k) - return rand(p.rng, p.actions) - else - return action(on_policy, s) - end -end - -loginfo(p::EpsGreedyPolicy, k) = (eps=p.eps(k),) - -# softmax -""" - SoftmaxPolicy <: ExplorationPolicy - -represents a softmax policy, sampling a random action according to a softmax function. 
-The softmax function converts the action values of the on policy into probabilities that are used for sampling. -A temperature parameter or function can be used to make the resulting distribution more or less wide. - -# Constructor - -`SoftmaxPolicy(problem, temperature::Union{Function, Float64}; rng=Random.GLOBAL_RNG)` - -If a function is passed for `temperature`, `temperature(k)` is called to compute the value of the temperature when calling `action(exploration_policy, on_policy, k, s)` - -# Fields - -- `temperature::Function` -- `rng::AbstractRNG` -- `actions::A` an indexable list of action - -""" -struct SoftmaxPolicy{T<:Function, R<:AbstractRNG, A} <: ExplorationPolicy - temperature::T - rng::R - actions::A -end - -function SoftmaxPolicy(problem, temperature::Function; - rng::AbstractRNG=Random.GLOBAL_RNG) - return SoftmaxPolicy(temperature, rng, actions(problem)) -end -function SoftmaxPolicy(problem, temperature::Real; - rng::AbstractRNG=Random.GLOBAL_RNG) - return SoftmaxPolicy(x->temperature, rng, actions(problem)) -end - -function POMDPs.action(p::SoftmaxPolicy, on_policy::Policy, k, s) - vals = actionvalues(on_policy, s) - vals ./= p.temperature(k) - maxval = maximum(vals) - exp_vals = exp.(vals .- maxval) - exp_vals /= sum(exp_vals) - return p.actions[sample(p.rng, Weights(exp_vals))] -end - -loginfo(p::SoftmaxPolicy, k) = (temperature=p.temperature(k),) diff --git a/src/function.jl b/src/function.jl deleted file mode 100644 index 5d0128d..0000000 --- a/src/function.jl +++ /dev/null @@ -1,27 +0,0 @@ -# FunctionPolicy -# A policy represented by a function -# maintained by @zsunberg - -""" -FunctionPolicy - -Policy `p=FunctionPolicy(f)` returns `f(x)` when `action(p, x)` is called. -""" -struct FunctionPolicy{F<:Function} <: Policy - f::F -end - -""" -FunctionSolver - -Solver for a FunctionPolicy. -""" -mutable struct FunctionSolver{F<:Function} <: Solver - f::F -end - -solve(s::FunctionSolver, mdp::Union{MDP,POMDP}) = FunctionPolicy(s.f) - -action(p::FunctionPolicy, x) = p.f(x) - -updater(p::FunctionPolicy) = PreviousObservationUpdater() diff --git a/src/playback.jl b/src/playback.jl deleted file mode 100644 index a085306..0000000 --- a/src/playback.jl +++ /dev/null @@ -1,45 +0,0 @@ - -""" - PlaybackPolicy{A<:AbstractArray, P<:Policy, V<:AbstractArray{<:Real}} -a policy that applies a fixed sequence of actions until they are all used and then falls back onto a backup policy until the end of the episode. - -Constructor: - - `PlaybackPolicy(actions::AbstractArray, backup_policy::Policy; logpdfs::AbstractArray{Float64, 1} = Float64[])` - -# Fields -- `actions::Vector{A}` a vector of actions to play back -- `backup_policy::Policy` the policy to use when all prescribed actions have been taken but the episode continues -- `logpdfs::Vector{Float64}` the log probability (density) of actions -- `i::Int64` the current action index -""" -mutable struct PlaybackPolicy{A<:AbstractArray, P<:Policy, V<:AbstractArray{<:Real}} <: Policy - actions::A - backup_policy::P - logpdfs::V - i::Int64 -end - -# Constructor for the PlaybackPolicy -PlaybackPolicy(actions::AbstractArray, backup_policy::Policy; logpdfs::AbstractArray{<:Real} = Float64[]) = PlaybackPolicy(actions, backup_policy, logpdfs, 1) - -# Action selection for the PlaybackPolicy -function POMDPs.action(p::PlaybackPolicy, s) - a = p.i <= length(p.actions) ? 
p.actions[p.i] : action(p.backup_policy, s) - p.i += 1 - a -end - -# Get the logpdf of the history from the playback policy and the backup policy -function Distributions.logpdf(p::PlaybackPolicy, h) - N = min(length(p.actions), length(h)) - # @assert all(collect(action_hist(h))[1:N] .== p.actions[1:N]) - @assert length(p.actions) == length(p.logpdfs) - if length(h) > N - return sum(p.logpdfs) + sum(logpdf(p.backup_policy, view(h, N+1:length(h)))) - else - return sum(p.logpdfs[1:N]) - end -end - - diff --git a/src/pretty_printing.jl b/src/pretty_printing.jl deleted file mode 100644 index b748012..0000000 --- a/src/pretty_printing.jl +++ /dev/null @@ -1,64 +0,0 @@ -""" - showpolicy([io], [mime], m::MDP, p::Policy) - showpolicy([io], [mime], statelist::AbstractVector, p::Policy) - showpolicy(...; pre=" ") - -Print the states in `m` or `statelist` and the actions from policy `p` corresponding to those states. - -For the MDP version, if `io[:limit]` is `true`, will only print enough states to fill the display. -""" -function showpolicy(io::IO, mime::MIME"text/plain", m::MDP, p::Policy; pre=" ", kwargs...) - slist = nothing - truncated = false - limited = get(io, :limit, false) - rows = first(get(io, :displaysize, displaysize(io))) - rows -= 3 # Yuck! This magic number is also in Base.print_matrix - try - if limited && length(states(m)) > rows - slist = collect(take(states(m), rows-1)) - truncated = true - else - slist = collect(states(m)) - end - catch ex - @info("""Unable to pretty-print policy: - $(sprint(showerror, ex)) - """) - show(io, mime, m) - return show(io, mime, p) - end - showpolicy(io, mime, slist, p; pre=pre, kwargs...) - if truncated - print(io, '\n', pre, "…") - end -end - -function showpolicy(io::IO, mime::MIME"text/plain", slist::AbstractVector, p::Policy; pre::AbstractString=" ") - S = eltype(slist) - sa_con = IOContext(io, :compact => true) - - if !isempty(slist) - # print first element without a newline - print(io, pre) - print_sa(sa_con, first(slist), p, S) - - # print all other elements - for s in slist[2:end] - print(io, '\n', pre) - print_sa(sa_con, s, p, S) - end - end -end - -showpolicy(io::IO, m::Union{MDP,AbstractVector}, p::Policy; kwargs...) = showpolicy(io, MIME("text/plain"), m, p; kwargs...) -showpolicy(m::Union{MDP,AbstractVector}, p::Policy; kwargs...) = showpolicy(stdout, m, p; kwargs...) - -function print_sa(io::IO, s, p::Policy, S::Type) - show(IOContext(io, :typeinfo => S), s) - print(io, " -> ") - try - show(io, action(p, s)) - catch ex - showerror(IOContext(io, :limit=>true), ex) - end -end diff --git a/src/random.jl b/src/random.jl deleted file mode 100644 index 0d39e02..0000000 --- a/src/random.jl +++ /dev/null @@ -1,48 +0,0 @@ -### RandomPolicy ### -# maintained by @zsunberg -""" - RandomPolicy{RNG<:AbstractRNG, P<:Union{POMDP,MDP}, U<:Updater} -a generic policy that uses the actions function to create a list of actions and then randomly samples an action from it. 
- -Constructor: - - `RandomPolicy(problem::Union{POMDP,MDP}; - rng=Random.GLOBAL_RNG, - updater=NothingUpdater())` - -# Fields -- `rng::RNG` a random number generator -- `probelm::P` the POMDP or MDP problem -- `updater::U` a belief updater (default to `NothingUpdater` in the above constructor) -""" -mutable struct RandomPolicy{RNG<:AbstractRNG, P<:Union{POMDP,MDP}, U<:Updater} <: Policy - rng::RNG - problem::P - updater::U # set this to use a custom updater, by default it will be a void updater -end -# The constructor below should be used to create the policy so that the action space is initialized correctly -RandomPolicy(problem::Union{POMDP,MDP}; - rng=Random.GLOBAL_RNG, - updater=NothingUpdater()) = RandomPolicy(rng, problem, updater) - -## policy execution ## -function action(policy::RandomPolicy, s) - return rand(policy.rng, actions(policy.problem, s)) -end - -function action(policy::RandomPolicy, b::Nothing) - return rand(policy.rng, actions(policy.problem)) -end - -## convenience functions ## -updater(policy::RandomPolicy) = policy.updater - - -""" -solver that produces a random policy -""" -mutable struct RandomSolver <: Solver - rng::AbstractRNG -end -RandomSolver(;rng=Random.GLOBAL_RNG) = RandomSolver(rng) -solve(solver::RandomSolver, problem::Union{POMDP,MDP}) = RandomPolicy(solver.rng, problem, NothingUpdater()) diff --git a/src/stochastic.jl b/src/stochastic.jl deleted file mode 100644 index 64228b5..0000000 --- a/src/stochastic.jl +++ /dev/null @@ -1,59 +0,0 @@ -### StochasticPolicy ### -# maintained by @etotheipluspi - -""" - StochasticPolicy{D, RNG <: AbstractRNG} - -Represents a stochastic policy. Action are sampled from an arbitrary distribution. - -Constructor: - - `StochasticPolicy(distribution; rng=Random.GLOBAL_RNG)` - -# Fields -- `distribution::D` -- `rng::RNG` a random number generator -""" -mutable struct StochasticPolicy{D, RNG <: AbstractRNG} <: Policy - distribution::D - rng::RNG -end -# The constructor below should be used to create the policy so that the action space is initialized correctly -StochasticPolicy(distribution; rng=Random.GLOBAL_RNG) = StochasticPolicy(distribution, rng) - -## policy execution ## -function action(policy::StochasticPolicy, s) - return rand(policy.rng, policy.distribution) -end - -## convenience functions ## -updater(policy::StochasticPolicy) = VoidUpdater() # since the stochastic policy does not depend on the belief - -# Samples actions uniformly -UniformRandomPolicy(problem, rng=Random.GLOBAL_RNG) = StochasticPolicy(actions(problem), rng) - -""" - CategoricalTabularPolicy - -represents a stochastic policy sampling an action from a categorical distribution with weights given by a `ValuePolicy` - -constructor: - -`CategoricalTabularPolicy(mdp::Union{POMDP,MDP}; rng=Random.GLOBAL_RNG)` - -# Fields -- `stochastic::StochasticPolicy` -- `value::ValuePolicy` -""" -mutable struct CategoricalTabularPolicy <: Policy - stochastic::StochasticPolicy - value::ValuePolicy -end -function CategoricalTabularPolicy(mdp::Union{POMDP,MDP}; rng=Random.GLOBAL_RNG) - CategoricalTabularPolicy(StochasticPolicy(Weights(zeros(length(actions((mdp))))), rng), ValuePolicy(mdp)) -end - -function action(policy::CategoricalTabularPolicy, s) - policy.stochastic.distribution = Weights(policy.value.value_table[stateindex(policy.value.mdp, s),:]) - return policy.value.act[sample(policy.stochastic.rng, policy.stochastic.distribution)] -end diff --git a/src/utility_wrapper.jl b/src/utility_wrapper.jl deleted file mode 100644 index 9b2c4aa..0000000 --- 
a/src/utility_wrapper.jl +++ /dev/null @@ -1,83 +0,0 @@ -""" - PolicyWrapper - -Flexible utility wrapper for a policy designed for collecting statistics about planning. - -Carries a function, a policy, and optionally a payload (that can be any type). - -The function should typically be defined with the do syntax. Each time `action` is called on the wrapper, this function will be called. - -If there is no payload, it will be called with two argments: the policy and the state/belief. If there is a payload, it will be called with three arguments: the policy, the payload, and the current state or belief. The function should return an appropriate action. The idea is that, in this function, `action(policy, s)` should be called, statistics from the policy/planner should be collected and saved in the payload, exceptions can be handled, and the action should be returned. - -Constructor - -`PolicyWrapper(policy::Policy; payload=nothing)` - -# Example -```julia -using POMDPModels -using POMDPToolbox - -mdp = GridWorld() -policy = RandomPolicy(mdp) -counts = Dict(a=>0 for a in actions(mdp)) - -# with a payload -statswrapper = PolicyWrapper(policy, payload=counts) do policy, counts, s - a = action(policy, s) - counts[a] += 1 - return a -end - -h = simulate(HistoryRecorder(max_steps=100), mdp, statswrapper) -for (a, count) in payload(statswrapper) - println("policy chose action \$a \$count of \$(n_steps(h)) times.") -end - -# without a payload -errwrapper = PolicyWrapper(policy) do policy, s - try - a = action(policy, s) - catch ex - @warn("Caught error in policy; using default") - a = :left - end - return a -end - -h = simulate(HistoryRecorder(max_steps=100), mdp, errwrapper) -``` - -# Fields -- `f::F` -- `policy::P` -- `payload::PL` - -""" -mutable struct PolicyWrapper{P<:Policy, F<:Function, PL} <: Policy - f::F - policy::P - payload::PL -end - -function PolicyWrapper(f::Function, policy::Policy; payload=nothing) - return PolicyWrapper(f, policy, payload) -end - -function PolicyWrapper(policy::Policy; payload=nothing) - return PolicyWrapper((p,s)->action(p.policy,s), policy, payload) -end - -function action(p::PolicyWrapper, s) - if p.payload == nothing - return p.f(p.policy, s) - else - return p.f(p.policy, p.payload, s) - end -end - -updater(p::PolicyWrapper) = updater(p.policy) - -payload(p::PolicyWrapper) = p.payload - -Random.seed!(p::PolicyWrapper, seed) = seed!(p.policy, seed) diff --git a/src/vector.jl b/src/vector.jl deleted file mode 100644 index da04f27..0000000 --- a/src/vector.jl +++ /dev/null @@ -1,72 +0,0 @@ -### Vector Policy ### -# maintained by @zsunberg and @etotheipluspi - -""" - VectorPolicy{S,A} -A generic MDP policy that consists of a vector of actions. The entry at `stateindex(mdp, s)` is the action that will be taken in state `s`. - -# Fields -- `mdp::MDP{S,A}` the MDP problem -- `act::Vector{A}` a vector of size |S| mapping state indices to actions -""" -mutable struct VectorPolicy{S,A} <: Policy - mdp::MDP{S,A} - act::Vector{A} -end - -action(p::VectorPolicy, s) = p.act[stateindex(p.mdp, s)] - -""" - VectorSolver{A} -Solver for VectorPolicy. Doesn't do any computation - just sets the action vector. 
-
-# Fields
-- `act::Vector{A}` the action vector
-"""
-mutable struct VectorSolver{A}
-    act::Vector{A}
-end
-
-function solve(s::VectorSolver{A}, mdp::MDP{S,A}) where {S,A}
-    return VectorPolicy{S,A}(mdp, s.act)
-end
-
-function Base.show(io::IO, mime::MIME"text/plain", p::VectorPolicy)
-    summary(io, p)
-    println(io, ':')
-    ds = get(io, :displaysize, displaysize(io))
-    ioc = IOContext(io, :displaysize=>(first(ds)-1, last(ds)))
-    showpolicy(ioc, mime, p.mdp, p)
-end
-
-"""
-    ValuePolicy{P<:Union{POMDP,MDP}, T<:AbstractMatrix{Float64}, A}
-A generic MDP/POMDP policy backed by a value table. In state `s`, the policy takes the action whose entry in row `stateindex(mdp, s)` of the value table is largest.
-It is expected that the order of the actions in the value table is consistent with the order of the actions in `act`.
-If `act` is not explicitly set in the constructor, `act` is ordered according to `actionindex`.
-
-# Fields
-- `mdp::P` the MDP problem
-- `value_table::T` the value table as a |S|x|A| matrix
-- `act::Vector{A}` the possible actions
-"""
-struct ValuePolicy{P<:Union{POMDP,MDP}, T<:AbstractMatrix{Float64}, A} <: Policy
-    mdp::P
-    value_table::T
-    act::Vector{A}
-end
-function ValuePolicy(mdp::Union{MDP,POMDP}, value_table=zeros(length(states(mdp)), length(actions(mdp))))
-    return ValuePolicy(mdp, value_table, ordered_actions(mdp))
-end
-
-action(p::ValuePolicy, s) = p.act[argmax(p.value_table[stateindex(p.mdp, s),:])]
-
-actionvalues(p::ValuePolicy, s) = p.value_table[stateindex(p.mdp, s), :]
-
-function Base.show(io::IO, mime::MIME"text/plain", p::ValuePolicy{M}) where M <: MDP
-    summary(io, p)
-    println(io, ':')
-    ds = get(io, :displaysize, displaysize(io))
-    ioc = IOContext(io, :displaysize=>(first(ds)-1, last(ds)))
-    showpolicy(ioc, mime, p.mdp, p)
-end
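
For downstream code, this change should normally only require swapping the `using` statement: the new module body does `@reexport using POMDPTools.Policies`, so the familiar policy types (`FunctionPolicy`, `RandomPolicy`, `AlphaVectorPolicy`, `ValuePolicy`, ...) are expected to keep working under their old names. A minimal migration sketch, assuming `POMDPModels` (already a test dependency in the Project.toml above) provides `SimpleGridWorld` as an example problem:

```julia
using POMDPs
using POMDPTools        # replaces `using POMDPPolicies`
using POMDPModels       # example problem; SimpleGridWorld is assumed here

mdp = SimpleGridWorld()

p_rand  = RandomPolicy(mdp)             # same constructor as in POMDPPolicies
p_fixed = FunctionPolicy(s -> :up)      # always move up

s = first(states(mdp))
action(p_rand, s)    # a random legal action
action(p_fixed, s)   # :up
```

Loading the deprecated package itself still works, but per the patched `src/POMDPPolicies.jl` it now only emits the `depwarn` and reexports `POMDPTools.Policies`.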
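
The removed `exploration_policies.jl` above documents how exploration policies are driven by a training step `k`; the same interface is expected from POMDPTools. A short usage sketch under the same assumption of a `SimpleGridWorld` example problem, using a zero-initialized `ValuePolicy` as the on-policy:

```julia
using POMDPs
using POMDPTools
using POMDPModels                      # assumed example problem

mdp = SimpleGridWorld()

# epsilon decays linearly from 1.0 to 0.05 over 10_000 steps, then stays at 0.05
schedule = LinearDecaySchedule(start=1.0, stop=0.05, steps=10_000)
explorer = EpsGreedyPolicy(mdp, schedule)

on_policy = ValuePolicy(mdp)           # greedy w.r.t. an (all-zero) value table
k = 1                                  # current training step
s = first(states(mdp))

a = action(explorer, on_policy, k, s)  # random action w.p. schedule(k), else greedy
loginfo(explorer, k)                   # (eps = schedule(k),), useful for logging
```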