From 8703a06a44c36407d04da0eb8a6ed02c9c895315 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Thu, 23 May 2019 09:26:11 -0700 Subject: [PATCH] pretty printing won't try to collect all the states --- src/POMDPPolicies.jl | 2 ++ src/pretty_printing.jl | 68 ++++++++++++++---------------------- test/test_pretty_printing.jl | 16 +++++++++ 3 files changed, 45 insertions(+), 41 deletions(-) diff --git a/src/POMDPPolicies.jl b/src/POMDPPolicies.jl index 6458e29..1f2c882 100644 --- a/src/POMDPPolicies.jl +++ b/src/POMDPPolicies.jl @@ -11,6 +11,8 @@ import POMDPs: action, value, solve, updater using BeliefUpdaters using POMDPModelTools +using Base.Iterators # for take + """ actionvalues(p::Policy, s) diff --git a/src/pretty_printing.jl b/src/pretty_printing.jl index 72ecb08..57bd918 100644 --- a/src/pretty_printing.jl +++ b/src/pretty_printing.jl @@ -5,61 +5,47 @@ Print the states in `m` or `statelist` and the actions from policy `p` corresponding to those states. -If `io[:limit]` is `true`, will only print enough states to fill the display. +For the MDP version, if `io[:limit]` is `true`, will only print enough states to fill the display. """ -function showpolicy(io::IO, mime::MIME"text/plain", m::MDP, p::Policy; kwargs...) +function showpolicy(io::IO, mime::MIME"text/plain", m::MDP, p::Policy; pre=" ", kwargs...) slist = nothing + truncated = false + limited = get(io, :limit, false) + rows = first(get(io, :displaysize, displaysize(io))) + rows -= 3 # Yuck! This magic number is also in Base.print_matrix try - slist = ordered_states(m) - catch - try + if limited && n_states(m) > rows + slist = collect(take(states(m), rows-1)) + truncated = true + else slist = collect(states(m)) - catch ex - @info("""Unable to pretty-print policy: - $(sprint(showerror, ex)) - """) - show(io, mime, m) - return show(io, mime, p) end + catch ex + @info("""Unable to pretty-print policy: + $(sprint(showerror, ex)) + """) + show(io, mime, m) + return show(io, mime, p) + end + showpolicy(io, mime, slist, p; pre=pre, kwargs...) + if truncated + print(io, '\n', pre, "…") end - showpolicy(io, mime, slist, p; kwargs...) end function showpolicy(io::IO, mime::MIME"text/plain", slist::AbstractVector, p::Policy; pre::AbstractString=" ") S = eltype(slist) - rows, cols = get(io, :displaysize, displaysize(io)) - rows -= 3 # Yuck! This magic number is also in Base.print_matrix sa_con = IOContext(io, :compact => true) if !isempty(slist) - if get(io, :limit, false) - # print first element without a newline - print(io, pre) - print_sa(sa_con, first(slist), p, S) - - # print middle elements - for s in slist[2:min(rows-1, end)] - print(io, '\n', pre) - print_sa(sa_con, s, p, S) - end - - # print last element or ... - if length(slist) == rows - print(io, '\n', pre) - print_sa(sa_con, last(slist), p, S) - elseif length(slist) > rows - print(io, '\n', pre, "…") - end - else - # print first element without a newline - print(io, pre) - print_sa(sa_con, first(slist), p, S) + # print first element without a newline + print(io, pre) + print_sa(sa_con, first(slist), p, S) - # print all other elements - for s in slist[2:end] - print(io, '\n', pre) - print_sa(sa_con, s, p, S) - end + # print all other elements + for s in slist[2:end] + print(io, '\n', pre) + print_sa(sa_con, s, p, S) end end end diff --git a/test/test_pretty_printing.jl b/test/test_pretty_printing.jl index efc82d3..829a043 100644 --- a/test/test_pretty_printing.jl +++ b/test/test_pretty_printing.jl @@ -7,10 +7,26 @@ let p = solve(solver, gw) + # test default @test sprint(showpolicy, gw, p) == " [1, 1] -> :left\n [2, 1] -> :left\n [1, 2] -> :left\n [2, 2] -> :left\n [-1, -1] -> :left" + # test with small display iob = IOBuffer() io = IOContext(iob, :limit=>true, :displaysize=>(7, 7)) showpolicy(io, gw, p, pre="@ ") @test String(take!(iob)) == "@ [1, 1] -> :left\n@ [2, 1] -> :left\n@ [1, 2] -> :left\n@ …" + + # test very long policy with small display + struct M <: MDP{Int, Int} + n::Int + end + iob = IOBuffer() + io = IOContext(iob, :limit=>true, :displaysize=>(7, 7)) + POMDPs.states(m::MDP) = 1:m.n + POMDPs.n_states(m::MDP) = m.n + POMDPs.actions(m::MDP) = 1:3 + m = M(1_000_000_000) + showpolicy(io, m, RandomPolicy(m)) + # Below, the actual values could be different because of the RandomPolicy, but length should be the same + @test length(String(take!(iob))) == length(" 1 -> 2\n 2 -> 1\n 3 -> 3\n …") end