Skip to content

Commit

Permalink
Improve reporting. (#22)
Browse files Browse the repository at this point in the history
Rewrote reporting, also allowing a time gap between consecutive reports.

Additional minor fixes: wrap included test files into sets, doc fixes.
  • Loading branch information
tpapp authored Oct 7, 2018
1 parent 0122200 commit d7c2797
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 64 deletions.
1 change: 0 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ GaussianKE

```@docs
NUTS
AbstractTuner
StepsizeTuner
StepsizeCovTuner
TunerSequence
Expand Down
8 changes: 7 additions & 1 deletion docs/src/lowlevel.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Instead of energies, *negative* energies are used in the code.

The following are used consistently for variables:

- ``: log density we sample from, see [this explanation](@ref ell-tutorial)
- ``: log density we sample from, supports the interface of [LogDensityProblems.AbstractLogDensityProblem](https://github.com/tpapp/LogDensityProblems.jl)
- `κ`: distribution/density that corresponds to kinetic energy
- `H`: Hamiltonian
- `q`: position
Expand Down Expand Up @@ -112,6 +112,12 @@ move
NUTS_transition
```

## Tuning

```@docs
DynamicHMC.AbstractTuner
```

## [Diagnostics](@id diagnostics_lowlevel)

```@docs
Expand Down
120 changes: 75 additions & 45 deletions src/reporting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@ export ReportSilent, ReportIO
"""
$(TYPEDEF)
Subtypes implement [`report!`](@ref), [`start_progress!`](@ref), and
[`end_progress!`](@ref).
Subtypes implement
1. [`start_progress!`](@ref), which is used to start a particular iteration,
2. [`report!`](@ref), which triggers the display of progress,
3. [`end_progress!`](@ref) which "frees" the progress report, which can then be reused.
"""
abstract type AbstractReport end

Expand All @@ -15,85 +20,110 @@ struct ReportSilent <: AbstractReport end

report!(::ReportSilent, objects...) = nothing

start_progress!(::ReportSilent, ::Union{Int, Nothing}, ::Any) = nothing
start_progress!(::ReportSilent, ::AbstractString; total_count = nothing) = nothing

end_progress!(::ReportSilent) = nothing

"""
$(TYPEDEF)
Display progress by printing lines to `io` if `countΔ` iterations *and* `time_nsΔ` nanoseconds
have passed since the last display.
$(FIELDS)
"""
mutable struct ReportIO{TIO <: IO} <: AbstractReport
"IO stream for reporting."
io::TIO
color::Union{Symbol, Int}
step_count::Int
total::Union{Int, Nothing}
last_count::Union{Int, Nothing}
last_time::UInt
"Color for report messages."
print_color::Union{Symbol, Int}
"Expected total count. When unknown, set to `nothing`."
total_count::Union{Int, Nothing}
"For comparing current count to the count at the last report. Not binding when negative."
countΔ::Int
"For comparing time to the time at the last report (in ns). Not binding when negative."
time_nsΔ::Int
"Time of starting the process. `nothing` unless start_progress! was called."
start_time_ns::Union{UInt, Nothing}
"Count when a report was last printed. `< 0` before `start_progress!`."
last_printed_count::Int
"Time (in ns) when a report was last printed."
last_printed_time_ns::UInt
end

"""
$SIGNATURES
Report to the given stream `io` (defaults to `stderr`).
For progress bars, emit new information every after `step_count` steps.
`color` is used with `print_with_color`.
See the documentation of the type for keyword arguments.
"""
ReportIO(; io::IO = stderr, color = :blue, step_count = 100) =
ReportIO(io, color, step_count, nothing, nothing, zero(UInt))
function ReportIO(; io::IO = stderr, print_color = :blue,
total_count::Union{Nothing, Integer} = nothing,
countΔ::Integer = total_count isa Integer ? total_count ÷ 10 : 100,
time_nsΔ::Integer = 10^9)
countΔ 0 && time_nsΔ 0 && @warn "progress report will be printed for every step"
ReportIO(io, print_color, total_count, countΔ, time_nsΔ, nothing, 0, time_ns())
end

"""
$SIGNATURES
Start a progress meter for an iteration. The second argument is either
Start a progress meter for an iteration.
- `nothing`, if the total number of steps is unknown,
`total_count` can be overwritten by a keyword argument.
- an integer, for the total number of steps.
After calling this function, [`report!`](@ref) should be used at every step with
an integer.
After calling this function, [`report!`](@ref) should be used at every step with an integer.
"""
function start_progress!(report::ReportIO, total, msg)
if total isa Integer
msg *= " ($(total) steps)"
end
printstyled(report.io, msg, '\n'; color = report.color, bold = true)
report.total = total
report.last_count = 0
report.last_time = time_ns()
function start_progress!(report::ReportIO, msg; total_count = report.total_count)
@unpack io, print_color = report
@argcheck report.start_time_ns nothing "end_progress was not called"
totalmsg = total_count nothing ? "unknown number of" : total_count
msg *= " ($(totalmsg) steps)"
printstyled(io, msg, '\n'; color = print_color, bold = true)
report.total_count = total_count
report.start_time_ns = report.last_printed_time_ns = time_ns()
report.last_printed_count = 0
nothing
end

function _report_avg_msg(report::ReportIO, count, _time_ns)
s_per_iteration = (_time_ns - report.start_time_ns) / count / 1_000_000_000
"$(round(s_per_iteration; sigdigits = 2)) s/step"
end

"""
$SIGNATURES
Terminate a progress meter.
"""
function end_progress!(report::ReportIO)
printstyled(report.io, " ...done\n"; bold = true, color = report.color)
report.last_count = nothing
function end_progress!(report::ReportIO, count::Integer = report.total_count::Integer)
avgmsg = _report_avg_msg(report, count, time_ns())
printstyled(report.io, "$(avgmsg) ...done\n"; bold = true, color = report.print_color)
report.start_time_ns = nothing
end

"""
$SIGNATURES
Display `objects` via the appropriate mechanism.
When a single `Int` is given, it is treated as the index of the current step.
Display `report` via the appropriate mechanism. `count` is the index of the current step.
"""
function report!(report::ReportIO, count::Int)
@unpack io, step_count, color, total = report
@argcheck report.last_count isa Int "start_progress! was not called."
if count % step_count == 0
function report!(report::ReportIO, count::Integer)
@unpack io, countΔ, time_nsΔ, start_time_ns, last_printed_count, last_printed_time_ns,
total_count = report
@argcheck start_time_ns nothing "start_progress! was not called."
_time_ns = time_ns()
ispastcount = countΔ count - last_printed_count
ispasttime = time_nsΔ _time_ns - last_printed_time_ns
if ispastcount && ispasttime
msg = "step $(count)"
if total isa Int
msg *= "/$(total)"
if total_count isa Int
msg *= " (of $(total_count))"
end
t = time_ns()
s_per_iteration = (t - report.last_time) / step_count / 1000
msg *= ", $(round(s_per_iteration; sigdigits = 2)) s/step"
printstyled(io, msg, '\n'; color = color)
report.last_time = t
report.last_count = count
msg *= ", " * _report_avg_msg(report, count, _time_ns)
printstyled(io, msg, '\n'; color = report.print_color)
report.last_printed_time_ns = _time_ns
report.last_printed_count = count
end
nothing
end
4 changes: 2 additions & 2 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ which has elements that conform to the sampler.
function mcmc(sampler::NUTS{Tv,Tf}, N::Int) where {Tv,Tf}
@unpack rng, H, q, ϵ, max_depth, report = sampler
sample = Vector{NUTS_Transition{Tv,Tf}}(undef, N)
start_progress!(report, N, "MCMC")
start_progress!(report, "MCMC"; total_count = N)
for i in 1:N
trans = NUTS_transition(rng, H, q, ϵ, max_depth)
q = trans.q
Expand All @@ -63,7 +63,7 @@ When the last two parameters are not specified, initialize using `adapting_ϵ`.
function mcmc_adapting_ϵ(sampler::NUTS{Tv,Tf}, N::Int, A_params, A) where {Tv,Tf}
@unpack rng, H, q, max_depth, report = sampler
sample = Vector{NUTS_Transition{Tv,Tf}}(undef, N)
start_progress!(report, N, "MCMC, adapting ϵ")
start_progress!(report, "MCMC, adapting ϵ"; total_count = N)
for i in 1:N
trans = NUTS_transition(rng, H, q, get_ϵ(A), max_depth)
A = adapt_stepsize(A_params, A, trans.a)
Expand Down
27 changes: 18 additions & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,23 @@ function rand_Hz(K)
H, z
end

include("test-Hamiltonian-leapfrog.jl")
include("test-buildingblocks.jl")
include("test-stepsize.jl")
include("test-sample-dummy.jl")
include("test-tuners.jl")
include("test-sample-normal.jl")
include("test-normal-mcmc.jl")
include("test-statistics.jl")
include("test-reporting.jl")
macro include_testset(filename)
@assert filename isa AbstractString
quote
@testset $(filename) begin
include($(filename))
end
end
end

@include_testset("test-Hamiltonian-leapfrog.jl")
@include_testset("test-buildingblocks.jl")
@include_testset("test-stepsize.jl")
@include_testset("test-sample-dummy.jl")
@include_testset("test-tuners.jl")
@include_testset("test-sample-normal.jl")
@include_testset("test-normal-mcmc.jl")
@include_testset("test-statistics.jl")
@include_testset("test-reporting.jl")

include("../docs/make.jl")
14 changes: 8 additions & 6 deletions test/test-reporting.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
@testset "reporting" begin
= DistributionLogDensity(MvNormal, 3)
@color_output false begin
output = @capture_err begin
sample, nuts = NUTS_init_tune_mcmc(RNG, ℓ, 1000; report = ReportIO())
output = @color_output false begin
@capture_err begin
sample, nuts = NUTS_init_tune_mcmc(RNG, ℓ, 1000;
report = ReportIO(; countΔ = 100, time_nsΔ = -1))
end
end
float_regex = raw"(0|[1-9][0-9]*)(\.[0-9]+)?([eE][+-]?[0-9]+)?"
function expectedA(msg, n)
r = "$msg \\($(n) steps\\)\\n"
r = "$(msg) \\($(n) steps\\)\\n"
for i in 100:100:n
r *= "step $(i)/$(n), \\d+\\.\\d+ s/step\\n"
r *= "step $(i) \\(of $(n)\\), $(float_regex) s/step\\n"
end
r *= " \\.\\.\\.done\\n"
r *= "$(float_regex) s/step \\.\\.\\.done\\n"
end
raw_regex = join(expectedA.(vcat(fill("MCMC, adapting ϵ", 7), ["MCMC"]),
[75, 25, 50, 100, 200, 400, 50, 1000]), "")
Expand Down

0 comments on commit d7c2797

Please sign in to comment.