Skip to content

Commit

Permalink
Merge pull request #10 from JuliaPOMDP/dont_collect
Browse files Browse the repository at this point in the history
pretty printing won't try to collect all the states
  • Loading branch information
zsunberg authored May 23, 2019
2 parents 3a0ff6a + 8703a06 commit 185c145
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 41 deletions.
2 changes: 2 additions & 0 deletions src/POMDPPolicies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import POMDPs: action, value, solve, updater
using BeliefUpdaters
using POMDPModelTools

using Base.Iterators # for take

"""
actionvalues(p::Policy, s)
Expand Down
68 changes: 27 additions & 41 deletions src/pretty_printing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,61 +5,47 @@
Print the states in `m` or `statelist` and the actions from policy `p` corresponding to those states.
If `io[:limit]` is `true`, will only print enough states to fill the display.
For the MDP version, if `io[:limit]` is `true`, will only print enough states to fill the display.
"""
function showpolicy(io::IO, mime::MIME"text/plain", m::MDP, p::Policy; kwargs...)
function showpolicy(io::IO, mime::MIME"text/plain", m::MDP, p::Policy; pre=" ", kwargs...)
slist = nothing
truncated = false
limited = get(io, :limit, false)
rows = first(get(io, :displaysize, displaysize(io)))
rows -= 3 # Yuck! This magic number is also in Base.print_matrix
try
slist = ordered_states(m)
catch
try
if limited && n_states(m) > rows
slist = collect(take(states(m), rows-1))
truncated = true
else
slist = collect(states(m))
catch ex
@info("""Unable to pretty-print policy:
$(sprint(showerror, ex))
""")
show(io, mime, m)
return show(io, mime, p)
end
catch ex
@info("""Unable to pretty-print policy:
$(sprint(showerror, ex))
""")
show(io, mime, m)
return show(io, mime, p)
end
showpolicy(io, mime, slist, p; pre=pre, kwargs...)
if truncated
print(io, '\n', pre, "")
end
showpolicy(io, mime, slist, p; kwargs...)
end

function showpolicy(io::IO, mime::MIME"text/plain", slist::AbstractVector, p::Policy; pre::AbstractString=" ")
S = eltype(slist)
rows, cols = get(io, :displaysize, displaysize(io))
rows -= 3 # Yuck! This magic number is also in Base.print_matrix
sa_con = IOContext(io, :compact => true)

if !isempty(slist)
if get(io, :limit, false)
# print first element without a newline
print(io, pre)
print_sa(sa_con, first(slist), p, S)

# print middle elements
for s in slist[2:min(rows-1, end)]
print(io, '\n', pre)
print_sa(sa_con, s, p, S)
end

# print last element or ...
if length(slist) == rows
print(io, '\n', pre)
print_sa(sa_con, last(slist), p, S)
elseif length(slist) > rows
print(io, '\n', pre, "")
end
else
# print first element without a newline
print(io, pre)
print_sa(sa_con, first(slist), p, S)
# print first element without a newline
print(io, pre)
print_sa(sa_con, first(slist), p, S)

# print all other elements
for s in slist[2:end]
print(io, '\n', pre)
print_sa(sa_con, s, p, S)
end
# print all other elements
for s in slist[2:end]
print(io, '\n', pre)
print_sa(sa_con, s, p, S)
end
end
end
Expand Down
16 changes: 16 additions & 0 deletions test/test_pretty_printing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,26 @@ let

p = solve(solver, gw)

# test default
@test sprint(showpolicy, gw, p) == " [1, 1] -> :left\n [2, 1] -> :left\n [1, 2] -> :left\n [2, 2] -> :left\n [-1, -1] -> :left"

# test with small display
iob = IOBuffer()
io = IOContext(iob, :limit=>true, :displaysize=>(7, 7))
showpolicy(io, gw, p, pre="@ ")
@test String(take!(iob)) == "@ [1, 1] -> :left\n@ [2, 1] -> :left\n@ [1, 2] -> :left\n@ …"

# test very long policy with small display
struct M <: MDP{Int, Int}
n::Int
end
iob = IOBuffer()
io = IOContext(iob, :limit=>true, :displaysize=>(7, 7))
POMDPs.states(m::MDP) = 1:m.n
POMDPs.n_states(m::MDP) = m.n
POMDPs.actions(m::MDP) = 1:3
m = M(1_000_000_000)
showpolicy(io, m, RandomPolicy(m))
# Below, the actual values could be different because of the RandomPolicy, but length should be the same
@test length(String(take!(iob))) == length(" 1 -> 2\n 2 -> 1\n 3 -> 3\n")
end

0 comments on commit 185c145

Please sign in to comment.