diff --git a/docs/src/index.md b/docs/src/index.md index b2434c7..dab1062 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -9,6 +9,7 @@ It currently provides: - a vector policy type - a wrapper to collect statistics and errors about policies +In addition, it provides the [`showpolicy`](@ref) function for printing policies similar to the way that matrices are printed in the repl. ```@contents ``` diff --git a/docs/src/showpolicy.md b/docs/src/showpolicy.md new file mode 100644 index 0000000..35346c9 --- /dev/null +++ b/docs/src/showpolicy.md @@ -0,0 +1,5 @@ +# Pretty Printing Policies + +```@docs +showpolicy +``` diff --git a/src/pretty_printing.jl b/src/pretty_printing.jl index 94f7e89..72ecb08 100644 --- a/src/pretty_printing.jl +++ b/src/pretty_printing.jl @@ -28,26 +28,38 @@ end function showpolicy(io::IO, mime::MIME"text/plain", slist::AbstractVector, p::Policy; pre::AbstractString=" ") S = eltype(slist) rows, cols = get(io, :displaysize, displaysize(io)) - ioc = IOContext(io, :compact => true, :displaysize => (1, cols-length(pre))) + rows -= 3 # Yuck! This magic number is also in Base.print_matrix + sa_con = IOContext(io, :compact => true) - if get(io, :limit, false) - for s in slist[1:min(rows-1, end)] - print(ioc, pre) - print_sa(ioc, s, p, S) - println(ioc) - end - if length(slist) == rows - print(ioc, pre) - print_sa(ioc, last(slist), p, S) - println(ioc) - elseif length(slist) > rows - println(ioc, pre, "…") - end - else - for s in slist - print(ioc, pre) - print_sa(ioc, s, p, S) - println(ioc) + if !isempty(slist) + if get(io, :limit, false) + # print first element without a newline + print(io, pre) + print_sa(sa_con, first(slist), p, S) + + # print middle elements + for s in slist[2:min(rows-1, end)] + print(io, '\n', pre) + print_sa(sa_con, s, p, S) + end + + # print last element or ... + if length(slist) == rows + print(io, '\n', pre) + print_sa(sa_con, last(slist), p, S) + elseif length(slist) > rows + print(io, '\n', pre, "…") + end + else + # print first element without a newline + print(io, pre) + print_sa(sa_con, first(slist), p, S) + + # print all other elements + for s in slist[2:end] + print(io, '\n', pre) + print_sa(sa_con, s, p, S) + end end end end @@ -56,14 +68,11 @@ showpolicy(io::IO, m::Union{MDP,AbstractVector}, p::Policy; kwargs...) = showpol showpolicy(m::Union{MDP,AbstractVector}, p::Policy; kwargs...) = showpolicy(stdout, m, p; kwargs...) function print_sa(io::IO, s, p::Policy, S::Type) - ds = get(io, :displaysize, displaysize(io)) - half_ds = (first(ds), div(last(ds)-4, 2)) - show(IOContext(io, :typeinfo => S, :displaysize => half_ds), s) + show(IOContext(io, :typeinfo => S), s) print(io, " -> ") - action_io = IOContext(io, :displaysize => half_ds) try - show(action_io, action(p, s)) + show(io, action(p, s)) catch ex - showerror(action_io, ex) + showerror(IOContext(io, :limit=>true), ex) end end diff --git a/src/vector.jl b/src/vector.jl index 6d31bf7..a08e463 100644 --- a/src/vector.jl +++ b/src/vector.jl @@ -31,10 +31,12 @@ function solve(s::VectorSolver{A}, mdp::MDP{S,A}) where {S,A} return VectorPolicy{S,A}(mdp, s.act) end -function Base.show(io::IO, p::VectorPolicy) +function Base.show(io::IO, mime::MIME"text/plain", p::VectorPolicy) summary(io, p) - println(io) - showpolicy(io, p.mdp, p) + println(io, ':') + ds = get(io, :displaysize, displaysize(io)) + ioc = IOContext(io, :displaysize=>(first(ds)-1, last(ds))) + showpolicy(ioc, mime, p.mdp, p) end """ @@ -61,8 +63,10 @@ action(p::ValuePolicy, s) = p.act[argmax(p.value_table[stateindex(p.mdp, s),:])] actionvalues(p::ValuePolicy, s) = p.value_table[stateindex(p.mdp, s), :] -function Base.show(io::IO, p::ValuePolicy{M}) where M <: MDP +function Base.show(io::IO, mime::MIME"text/plain", p::ValuePolicy{M}) where M <: MDP summary(io, p) - println(io) - showpolicy(io, p.mdp, p) + println(io, ':') + ds = get(io, :displaysize, displaysize(io)) + ioc = IOContext(io, :displaysize=>(first(ds)-1, last(ds))) + showpolicy(io, mime, p.mdp, p) end diff --git a/test/test_pretty_printing.jl b/test/test_pretty_printing.jl index 7633bbe..efc82d3 100644 --- a/test/test_pretty_printing.jl +++ b/test/test_pretty_printing.jl @@ -7,10 +7,10 @@ let p = solve(solver, gw) - @test sprint(showpolicy, gw, p) == " [1, 1] -> :left\n [2, 1] -> :left\n [1, 2] -> :left\n [2, 2] -> :left\n [-1, -1] -> :left\n" + @test sprint(showpolicy, gw, p) == " [1, 1] -> :left\n [2, 1] -> :left\n [1, 2] -> :left\n [2, 2] -> :left\n [-1, -1] -> :left" iob = IOBuffer() - io = IOContext(iob, :limit=>true, :displaysize=>(4, 7)) + io = IOContext(iob, :limit=>true, :displaysize=>(7, 7)) showpolicy(io, gw, p, pre="@ ") - @test String(take!(iob)) == "@ [1, 1] -> :left\n@ [2, 1] -> :left\n@ [1, 2] -> :left\n@ …\n" + @test String(take!(iob)) == "@ [1, 1] -> :left\n@ [2, 1] -> :left\n@ [1, 2] -> :left\n@ …" end diff --git a/test/test_vector_policy.jl b/test/test_vector_policy.jl index d13cf9e..8df0705 100644 --- a/test/test_vector_policy.jl +++ b/test/test_vector_policy.jl @@ -7,7 +7,10 @@ let p = solve(solver, gw) - @test string(p) == "VectorPolicy{GridWorldState,Symbol}\n GridWorldState(1, 1, false) -> :left\n GridWorldState(2, 1, false) -> :left\n GridWorldState(1, 2, false) -> :left\n GridWorldState(2, 2, false) -> :left\n GridWorldState(0, 0, true) -> :left\n" + io = IOBuffer() + d = TextDisplay(io) + display(d, p) + @test String(take!(io)) == "VectorPolicy{GridWorldState,Symbol}:\n GridWorldState(1, 1, false) -> :left\n GridWorldState(2, 1, false) -> :left\n GridWorldState(1, 2, false) -> :left\n GridWorldState(2, 2, false) -> :left\n GridWorldState(0, 0, true) -> :left" for s1 in states(gw) @test action(p, s1) == GridWorldAction(:left) @@ -23,5 +26,8 @@ let @inferred(action(p3, s2)) isa GridWorldAction end - @test string(p3) == "ValuePolicy{LegacyGridWorld,Array{Float64,2},Symbol}\n GridWorldState(1, 1, false) -> :up\n GridWorldState(2, 1, false) -> :up\n GridWorldState(1, 2, false) -> :up\n GridWorldState(2, 2, false) -> :up\n GridWorldState(0, 0, true) -> :up\n" + io = IOBuffer() + d = TextDisplay(io) + display(d, p3) + @test String(take!(io)) == "ValuePolicy{LegacyGridWorld,Array{Float64,2},Symbol}:\n GridWorldState(1, 1, false) -> :up\n GridWorldState(2, 1, false) -> :up\n GridWorldState(1, 2, false) -> :up\n GridWorldState(2, 2, false) -> :up\n GridWorldState(0, 0, true) -> :up" end