Skip to content

Commit

Permalink
testing, etc. for pretty printing
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed May 22, 2019
1 parent f74e4c0 commit cbf9696
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 29 deletions.
49 changes: 21 additions & 28 deletions src/pretty_printing.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""
showpolicy([io], [mime], m::MDP, p::Policy)
showpolicy([io], [mime], statelist::AbstractVector, p::Policy_
showpolicy([io], [mime], statelist::AbstractVector, p::Policy)
showpolicy(...; pre=" ")
Print the states in `m` or `statelist` and the actions from policy `p` corresponding to those states.
If `io[:limit]` is `true`, will only print enough states to fill the display.
"""
function showpolicy(io::IO, mime::MIME"text/plain", m::MDP, p::Policy)
function showpolicy(io::IO, mime::MIME"text/plain", m::MDP, p::Policy; kwargs...)
slist = nothing
try
slist = ordered_states(m)
Expand All @@ -21,37 +22,42 @@ function showpolicy(io::IO, mime::MIME"text/plain", m::MDP, p::Policy)
return show(io, mime, p)
end
end
showpolicy(io, mime, slist, p)
showpolicy(io, mime, slist, p; kwargs...)
end

function showpolicy(io::IO, mime::MIME"text/plain", slist::AbstractVector, p::Policy)
function showpolicy(io::IO, mime::MIME"text/plain", slist::AbstractVector, p::Policy; pre::AbstractString=" ")
S = eltype(slist)
rows, cols = get(io, :displaysize, displaysize(io))
ioc = IOContext(io, :compact => true, :displaysize => (1, cols))
ioc = IOContext(io, :compact => true, :displaysize => (1, cols-length(pre)))

if get(io, :limit, false)
for s in slist[1:min(rows-1, end)]
print_sa_line(ioc, s, p, S)
print(ioc, pre)
print_sa(ioc, s, p, S)
println(ioc)
end
if length(slist) == rows
print_sa_line(ioc, last(slist), p, S)
else
println("")
print(ioc, pre)
print_sa(ioc, last(slist), p, S)
println(ioc)
elseif length(slist) > rows
println(ioc, pre, "")
end
else
for s in slist
print_sa_line(ioc, s, p, S)
print(ioc, pre)
print_sa(ioc, s, p, S)
println(ioc)
end
end
end

showpolicy(io::IO, m::Union{MDP,AbstractVector}, p::Policy) = showpolicy(io, MIME("text/plain"), m, p)
showpolicy(m::Union{MDP,AbstractVector}, p::Policy) = showpolicy(stdout, m, p)
showpolicy(io::IO, m::Union{MDP,AbstractVector}, p::Policy; kwargs...) = showpolicy(io, MIME("text/plain"), m, p; kwargs...)
showpolicy(m::Union{MDP,AbstractVector}, p::Policy; kwargs...) = showpolicy(stdout, m, p; kwargs...)

function print_sa_line(io::IO, s, p::Policy, S::Type)
print(io, ' ')
function print_sa(io::IO, s, p::Policy, S::Type)
ds = get(io, :displaysize, displaysize(io))
half_ds = (first(ds), div(last(ds)-5, 2))
half_ds = (first(ds), div(last(ds)-4, 2))
show(IOContext(io, :typeinfo => S, :displaysize => half_ds), s)
print(io, " -> ")
action_io = IOContext(io, :displaysize => half_ds)
Expand All @@ -60,17 +66,4 @@ function print_sa_line(io::IO, s, p::Policy, S::Type)
catch ex
showerror(action_io, ex)
end
println(io)
end

const PAIRTYPES = Union{Pair{M, P}, Tuple{M, P}} where {M<:MDP, P<:Policy}

function Base.show(io::IO, mime::MIME"text/plain", pair::PAIRTYPES)
summary(io, pair)
println(io)
remaining = first(displaysize(io)) - 1
ioc = IOContext(io, :displaysize => displaysize(io))
showpolicy(ioc, first(pair), last(pair))
end

Base.show(io::IO, pair::PAIRTYPES) = show(io, MIME("text/plain"), pair)
11 changes: 11 additions & 0 deletions src/vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ function solve(s::VectorSolver{A}, mdp::MDP{S,A}) where {S,A}
return VectorPolicy{S,A}(mdp, s.act)
end

function Base.show(io::IO, p::VectorPolicy)
summary(io, p)
println(io)
showpolicy(io, p.mdp, p)
end

"""
ValuePolicy{P<:Union{POMDP,MDP}, T<:AbstractMatrix{Float64}, A}
Expand All @@ -55,3 +60,9 @@ end
action(p::ValuePolicy, s) = p.act[argmax(p.value_table[stateindex(p.mdp, s),:])]

actionvalues(p::ValuePolicy, s) = p.value_table[stateindex(p.mdp, s), :]

function Base.show(io::IO, p::ValuePolicy{M}) where M <: MDP
summary(io, p)
println(io)
showpolicy(io, p.mdp, p)
end
7 changes: 6 additions & 1 deletion test/test_pretty_printing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,10 @@ let

p = solve(solver, gw)

display(gw=>p)
@test sprint(showpolicy, gw, p) == " [1, 1] -> :left\n [2, 1] -> :left\n [1, 2] -> :left\n [2, 2] -> :left\n [-1, -1] -> :left\n"

iob = IOBuffer()
io = IOContext(iob, :limit=>true, :displaysize=>(4, 7))
showpolicy(io, gw, p, pre="@ ")
@test String(take!(iob)) == "@ [1, 1] -> :left\n@ [2, 1] -> :left\n@ [1, 2] -> :left\n@ …\n"
end
4 changes: 4 additions & 0 deletions test/test_vector_policy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ let

p = solve(solver, gw)

@test string(p) == "VectorPolicy{GridWorldState,Symbol}\n GridWorldState(1, 1, false) -> :left\n GridWorldState(2, 1, false) -> :left\n GridWorldState(1, 2, false) -> :left\n GridWorldState(2, 2, false) -> :left\n GridWorldState(0, 0, true) -> :left\n"

for s1 in states(gw)
@test action(p, s1) == GridWorldAction(:left)
end
Expand All @@ -20,4 +22,6 @@ let
for s2 in states(gw)
@inferred(action(p3, s2)) isa GridWorldAction
end

@test string(p3) == "ValuePolicy{LegacyGridWorld,Array{Float64,2},Symbol}\n GridWorldState(1, 1, false) -> :up\n GridWorldState(2, 1, false) -> :up\n GridWorldState(1, 2, false) -> :up\n GridWorldState(2, 2, false) -> :up\n GridWorldState(0, 0, true) -> :up\n"
end

0 comments on commit cbf9696

Please sign in to comment.