Skip to content

Commit

Permalink
finished some tasks in #9
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed May 22, 2019
1 parent cbf9696 commit 3a0ff6a
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 36 deletions.
1 change: 1 addition & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ It currently provides:
- a vector policy type
- a wrapper to collect statistics and errors about policies

In addition, it provides the [`showpolicy`](@ref) function for printing policies similar to the way that matrices are printed in the repl.

```@contents
```
5 changes: 5 additions & 0 deletions docs/src/showpolicy.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Pretty Printing Policies

```@docs
showpolicy
```
59 changes: 34 additions & 25 deletions src/pretty_printing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,38 @@ end
function showpolicy(io::IO, mime::MIME"text/plain", slist::AbstractVector, p::Policy; pre::AbstractString=" ")
S = eltype(slist)
rows, cols = get(io, :displaysize, displaysize(io))
ioc = IOContext(io, :compact => true, :displaysize => (1, cols-length(pre)))
rows -= 3 # Yuck! This magic number is also in Base.print_matrix
sa_con = IOContext(io, :compact => true)

if get(io, :limit, false)
for s in slist[1:min(rows-1, end)]
print(ioc, pre)
print_sa(ioc, s, p, S)
println(ioc)
end
if length(slist) == rows
print(ioc, pre)
print_sa(ioc, last(slist), p, S)
println(ioc)
elseif length(slist) > rows
println(ioc, pre, "")
end
else
for s in slist
print(ioc, pre)
print_sa(ioc, s, p, S)
println(ioc)
if !isempty(slist)
if get(io, :limit, false)
# print first element without a newline
print(io, pre)
print_sa(sa_con, first(slist), p, S)

# print middle elements
for s in slist[2:min(rows-1, end)]
print(io, '\n', pre)
print_sa(sa_con, s, p, S)
end

# print last element or ...
if length(slist) == rows
print(io, '\n', pre)
print_sa(sa_con, last(slist), p, S)
elseif length(slist) > rows
print(io, '\n', pre, "")
end
else
# print first element without a newline
print(io, pre)
print_sa(sa_con, first(slist), p, S)

# print all other elements
for s in slist[2:end]
print(io, '\n', pre)
print_sa(sa_con, s, p, S)
end
end
end
end
Expand All @@ -56,14 +68,11 @@ showpolicy(io::IO, m::Union{MDP,AbstractVector}, p::Policy; kwargs...) = showpol
showpolicy(m::Union{MDP,AbstractVector}, p::Policy; kwargs...) = showpolicy(stdout, m, p; kwargs...)

function print_sa(io::IO, s, p::Policy, S::Type)
ds = get(io, :displaysize, displaysize(io))
half_ds = (first(ds), div(last(ds)-4, 2))
show(IOContext(io, :typeinfo => S, :displaysize => half_ds), s)
show(IOContext(io, :typeinfo => S), s)
print(io, " -> ")
action_io = IOContext(io, :displaysize => half_ds)
try
show(action_io, action(p, s))
show(io, action(p, s))
catch ex
showerror(action_io, ex)
showerror(IOContext(io, :limit=>true), ex)
end
end
16 changes: 10 additions & 6 deletions src/vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ function solve(s::VectorSolver{A}, mdp::MDP{S,A}) where {S,A}
return VectorPolicy{S,A}(mdp, s.act)
end

function Base.show(io::IO, p::VectorPolicy)
function Base.show(io::IO, mime::MIME"text/plain", p::VectorPolicy)
summary(io, p)
println(io)
showpolicy(io, p.mdp, p)
println(io, ':')
ds = get(io, :displaysize, displaysize(io))
ioc = IOContext(io, :displaysize=>(first(ds)-1, last(ds)))
showpolicy(ioc, mime, p.mdp, p)
end

"""
Expand All @@ -61,8 +63,10 @@ action(p::ValuePolicy, s) = p.act[argmax(p.value_table[stateindex(p.mdp, s),:])]

actionvalues(p::ValuePolicy, s) = p.value_table[stateindex(p.mdp, s), :]

function Base.show(io::IO, p::ValuePolicy{M}) where M <: MDP
function Base.show(io::IO, mime::MIME"text/plain", p::ValuePolicy{M}) where M <: MDP
summary(io, p)
println(io)
showpolicy(io, p.mdp, p)
println(io, ':')
ds = get(io, :displaysize, displaysize(io))
ioc = IOContext(io, :displaysize=>(first(ds)-1, last(ds)))
showpolicy(io, mime, p.mdp, p)
end
6 changes: 3 additions & 3 deletions test/test_pretty_printing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ let

p = solve(solver, gw)

@test sprint(showpolicy, gw, p) == " [1, 1] -> :left\n [2, 1] -> :left\n [1, 2] -> :left\n [2, 2] -> :left\n [-1, -1] -> :left\n"
@test sprint(showpolicy, gw, p) == " [1, 1] -> :left\n [2, 1] -> :left\n [1, 2] -> :left\n [2, 2] -> :left\n [-1, -1] -> :left"

iob = IOBuffer()
io = IOContext(iob, :limit=>true, :displaysize=>(4, 7))
io = IOContext(iob, :limit=>true, :displaysize=>(7, 7))
showpolicy(io, gw, p, pre="@ ")
@test String(take!(iob)) == "@ [1, 1] -> :left\n@ [2, 1] -> :left\n@ [1, 2] -> :left\n@ …\n"
@test String(take!(iob)) == "@ [1, 1] -> :left\n@ [2, 1] -> :left\n@ [1, 2] -> :left\n@ …"
end
10 changes: 8 additions & 2 deletions test/test_vector_policy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ let

p = solve(solver, gw)

@test string(p) == "VectorPolicy{GridWorldState,Symbol}\n GridWorldState(1, 1, false) -> :left\n GridWorldState(2, 1, false) -> :left\n GridWorldState(1, 2, false) -> :left\n GridWorldState(2, 2, false) -> :left\n GridWorldState(0, 0, true) -> :left\n"
io = IOBuffer()
d = TextDisplay(io)
display(d, p)
@test String(take!(io)) == "VectorPolicy{GridWorldState,Symbol}:\n GridWorldState(1, 1, false) -> :left\n GridWorldState(2, 1, false) -> :left\n GridWorldState(1, 2, false) -> :left\n GridWorldState(2, 2, false) -> :left\n GridWorldState(0, 0, true) -> :left"

for s1 in states(gw)
@test action(p, s1) == GridWorldAction(:left)
Expand All @@ -23,5 +26,8 @@ let
@inferred(action(p3, s2)) isa GridWorldAction
end

@test string(p3) == "ValuePolicy{LegacyGridWorld,Array{Float64,2},Symbol}\n GridWorldState(1, 1, false) -> :up\n GridWorldState(2, 1, false) -> :up\n GridWorldState(1, 2, false) -> :up\n GridWorldState(2, 2, false) -> :up\n GridWorldState(0, 0, true) -> :up\n"
io = IOBuffer()
d = TextDisplay(io)
display(d, p3)
@test String(take!(io)) == "ValuePolicy{LegacyGridWorld,Array{Float64,2},Symbol}:\n GridWorldState(1, 1, false) -> :up\n GridWorldState(2, 1, false) -> :up\n GridWorldState(1, 2, false) -> :up\n GridWorldState(2, 2, false) -> :up\n GridWorldState(0, 0, true) -> :up"
end

0 comments on commit 3a0ff6a

Please sign in to comment.