From cbf9696009b8f1be94d1e30ca595c871f035e636 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Wed, 22 May 2019 09:11:02 -0700 Subject: [PATCH] testing, etc. for pretty printing --- src/pretty_printing.jl | 49 ++++++++++++++++-------------------- src/vector.jl | 11 ++++++++ test/test_pretty_printing.jl | 7 +++++- test/test_vector_policy.jl | 4 +++ 4 files changed, 42 insertions(+), 29 deletions(-) diff --git a/src/pretty_printing.jl b/src/pretty_printing.jl index bc75bbb..94f7e89 100644 --- a/src/pretty_printing.jl +++ b/src/pretty_printing.jl @@ -1,12 +1,13 @@ """ showpolicy([io], [mime], m::MDP, p::Policy) - showpolicy([io], [mime], statelist::AbstractVector, p::Policy_ + showpolicy([io], [mime], statelist::AbstractVector, p::Policy) + showpolicy(...; pre=" ") Print the states in `m` or `statelist` and the actions from policy `p` corresponding to those states. If `io[:limit]` is `true`, will only print enough states to fill the display. """ -function showpolicy(io::IO, mime::MIME"text/plain", m::MDP, p::Policy) +function showpolicy(io::IO, mime::MIME"text/plain", m::MDP, p::Policy; kwargs...) slist = nothing try slist = ordered_states(m) @@ -21,37 +22,42 @@ function showpolicy(io::IO, mime::MIME"text/plain", m::MDP, p::Policy) return show(io, mime, p) end end - showpolicy(io, mime, slist, p) + showpolicy(io, mime, slist, p; kwargs...) end -function showpolicy(io::IO, mime::MIME"text/plain", slist::AbstractVector, p::Policy) +function showpolicy(io::IO, mime::MIME"text/plain", slist::AbstractVector, p::Policy; pre::AbstractString=" ") S = eltype(slist) rows, cols = get(io, :displaysize, displaysize(io)) - ioc = IOContext(io, :compact => true, :displaysize => (1, cols)) + ioc = IOContext(io, :compact => true, :displaysize => (1, cols-length(pre))) if get(io, :limit, false) for s in slist[1:min(rows-1, end)] - print_sa_line(ioc, s, p, S) + print(ioc, pre) + print_sa(ioc, s, p, S) + println(ioc) end if length(slist) == rows - print_sa_line(ioc, last(slist), p, S) - else - println(" …") + print(ioc, pre) + print_sa(ioc, last(slist), p, S) + println(ioc) + elseif length(slist) > rows + println(ioc, pre, "…") end else for s in slist - print_sa_line(ioc, s, p, S) + print(ioc, pre) + print_sa(ioc, s, p, S) + println(ioc) end end end -showpolicy(io::IO, m::Union{MDP,AbstractVector}, p::Policy) = showpolicy(io, MIME("text/plain"), m, p) -showpolicy(m::Union{MDP,AbstractVector}, p::Policy) = showpolicy(stdout, m, p) +showpolicy(io::IO, m::Union{MDP,AbstractVector}, p::Policy; kwargs...) = showpolicy(io, MIME("text/plain"), m, p; kwargs...) +showpolicy(m::Union{MDP,AbstractVector}, p::Policy; kwargs...) = showpolicy(stdout, m, p; kwargs...) -function print_sa_line(io::IO, s, p::Policy, S::Type) - print(io, ' ') +function print_sa(io::IO, s, p::Policy, S::Type) ds = get(io, :displaysize, displaysize(io)) - half_ds = (first(ds), div(last(ds)-5, 2)) + half_ds = (first(ds), div(last(ds)-4, 2)) show(IOContext(io, :typeinfo => S, :displaysize => half_ds), s) print(io, " -> ") action_io = IOContext(io, :displaysize => half_ds) @@ -60,17 +66,4 @@ function print_sa_line(io::IO, s, p::Policy, S::Type) catch ex showerror(action_io, ex) end - println(io) end - -const PAIRTYPES = Union{Pair{M, P}, Tuple{M, P}} where {M<:MDP, P<:Policy} - -function Base.show(io::IO, mime::MIME"text/plain", pair::PAIRTYPES) - summary(io, pair) - println(io) - remaining = first(displaysize(io)) - 1 - ioc = IOContext(io, :displaysize => displaysize(io)) - showpolicy(ioc, first(pair), last(pair)) -end - -Base.show(io::IO, pair::PAIRTYPES) = show(io, MIME("text/plain"), pair) diff --git a/src/vector.jl b/src/vector.jl index 4ef42a1..6d31bf7 100644 --- a/src/vector.jl +++ b/src/vector.jl @@ -31,6 +31,11 @@ function solve(s::VectorSolver{A}, mdp::MDP{S,A}) where {S,A} return VectorPolicy{S,A}(mdp, s.act) end +function Base.show(io::IO, p::VectorPolicy) + summary(io, p) + println(io) + showpolicy(io, p.mdp, p) +end """ ValuePolicy{P<:Union{POMDP,MDP}, T<:AbstractMatrix{Float64}, A} @@ -55,3 +60,9 @@ end action(p::ValuePolicy, s) = p.act[argmax(p.value_table[stateindex(p.mdp, s),:])] actionvalues(p::ValuePolicy, s) = p.value_table[stateindex(p.mdp, s), :] + +function Base.show(io::IO, p::ValuePolicy{M}) where M <: MDP + summary(io, p) + println(io) + showpolicy(io, p.mdp, p) +end diff --git a/test/test_pretty_printing.jl b/test/test_pretty_printing.jl index ebcf2c9..7633bbe 100644 --- a/test/test_pretty_printing.jl +++ b/test/test_pretty_printing.jl @@ -7,5 +7,10 @@ let p = solve(solver, gw) - display(gw=>p) + @test sprint(showpolicy, gw, p) == " [1, 1] -> :left\n [2, 1] -> :left\n [1, 2] -> :left\n [2, 2] -> :left\n [-1, -1] -> :left\n" + + iob = IOBuffer() + io = IOContext(iob, :limit=>true, :displaysize=>(4, 7)) + showpolicy(io, gw, p, pre="@ ") + @test String(take!(iob)) == "@ [1, 1] -> :left\n@ [2, 1] -> :left\n@ [1, 2] -> :left\n@ …\n" end diff --git a/test/test_vector_policy.jl b/test/test_vector_policy.jl index 1d92fd1..d13cf9e 100644 --- a/test/test_vector_policy.jl +++ b/test/test_vector_policy.jl @@ -7,6 +7,8 @@ let p = solve(solver, gw) + @test string(p) == "VectorPolicy{GridWorldState,Symbol}\n GridWorldState(1, 1, false) -> :left\n GridWorldState(2, 1, false) -> :left\n GridWorldState(1, 2, false) -> :left\n GridWorldState(2, 2, false) -> :left\n GridWorldState(0, 0, true) -> :left\n" + for s1 in states(gw) @test action(p, s1) == GridWorldAction(:left) end @@ -20,4 +22,6 @@ let for s2 in states(gw) @inferred(action(p3, s2)) isa GridWorldAction end + + @test string(p3) == "ValuePolicy{LegacyGridWorld,Array{Float64,2},Symbol}\n GridWorldState(1, 1, false) -> :up\n GridWorldState(2, 1, false) -> :up\n GridWorldState(1, 2, false) -> :up\n GridWorldState(2, 2, false) -> :up\n GridWorldState(0, 0, true) -> :up\n" end