diff --git a/docs/src/api.md b/docs/src/api.md index abc78bb..12bdf5c 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -55,7 +55,6 @@ GaussianKE ```@docs NUTS -AbstractTuner StepsizeTuner StepsizeCovTuner TunerSequence diff --git a/docs/src/lowlevel.md b/docs/src/lowlevel.md index b6451ae..e2c0d3e 100644 --- a/docs/src/lowlevel.md +++ b/docs/src/lowlevel.md @@ -10,7 +10,7 @@ Instead of energies, *negative* energies are used in the code. The following are used consistently for variables: -- `ℓ`: log density we sample from, see [this explanation](@ref ell-tutorial) +- `ℓ`: log density we sample from, supports the interface of [LogDensityProblems.AbstractLogDensityProblem](https://github.com/tpapp/LogDensityProblems.jl) - `κ`: distribution/density that corresponds to kinetic energy - `H`: Hamiltonian - `q`: position @@ -112,6 +112,12 @@ move NUTS_transition ``` +## Tuning + +```@docs +DynamicHMC.AbstractTuner +``` + ## [Diagnostics](@id diagnostics_lowlevel) ```@docs diff --git a/src/reporting.jl b/src/reporting.jl index 2802fcc..69744b3 100644 --- a/src/reporting.jl +++ b/src/reporting.jl @@ -3,8 +3,13 @@ export ReportSilent, ReportIO """ $(TYPEDEF) -Subtypes implement [`report!`](@ref), [`start_progress!`](@ref), and -[`end_progress!`](@ref). +Subtypes implement + +1. [`start_progress!`](@ref), which is used to start a particular iteration, + +2. [`report!`](@ref), which triggers the display of progress, + +3. [`end_progress!`](@ref) which "frees" the progress report, which can then be reused. """ abstract type AbstractReport end @@ -15,17 +20,35 @@ struct ReportSilent <: AbstractReport end report!(::ReportSilent, objects...) = nothing -start_progress!(::ReportSilent, ::Union{Int, Nothing}, ::Any) = nothing +start_progress!(::ReportSilent, ::AbstractString; total_count = nothing) = nothing end_progress!(::ReportSilent) = nothing +""" +$(TYPEDEF) + +Display progress by printing lines to `io` if `countΔ` iterations *and* `time_nsΔ` nanoseconds +have passed since the last display. + +$(FIELDS) +""" mutable struct ReportIO{TIO <: IO} <: AbstractReport + "IO stream for reporting." io::TIO - color::Union{Symbol, Int} - step_count::Int - total::Union{Int, Nothing} - last_count::Union{Int, Nothing} - last_time::UInt + "Color for report messages." + print_color::Union{Symbol, Int} + "Expected total count. When unknown, set to `nothing`." + total_count::Union{Int, Nothing} + "For comparing current count to the count at the last report. Not binding when negative." + countΔ::Int + "For comparing time to the time at the last report (in ns). Not binding when negative." + time_nsΔ::Int + "Time of starting the process. `nothing` unless start_progress! was called." + start_time_ns::Union{UInt, Nothing} + "Count when a report was last printed. `< 0` before `start_progress!`." + last_printed_count::Int + "Time (in ns) when a report was last printed." + last_printed_time_ns::UInt end """ @@ -33,67 +56,74 @@ end Report to the given stream `io` (defaults to `stderr`). -For progress bars, emit new information every after `step_count` steps. - -`color` is used with `print_with_color`. +See the documentation of the type for keyword arguments. """ -ReportIO(; io::IO = stderr, color = :blue, step_count = 100) = - ReportIO(io, color, step_count, nothing, nothing, zero(UInt)) +function ReportIO(; io::IO = stderr, print_color = :blue, + total_count::Union{Nothing, Integer} = nothing, + countΔ::Integer = total_count isa Integer ? total_count ÷ 10 : 100, + time_nsΔ::Integer = 10^9) + countΔ ≤ 0 && time_nsΔ ≤ 0 && @warn "progress report will be printed for every step" + ReportIO(io, print_color, total_count, countΔ, time_nsΔ, nothing, 0, time_ns()) +end """ $SIGNATURES -Start a progress meter for an iteration. The second argument is either +Start a progress meter for an iteration. -- `nothing`, if the total number of steps is unknown, +`total_count` can be overwritten by a keyword argument. -- an integer, for the total number of steps. - -After calling this function, [`report!`](@ref) should be used at every step with -an integer. +After calling this function, [`report!`](@ref) should be used at every step with an integer. """ -function start_progress!(report::ReportIO, total, msg) - if total isa Integer - msg *= " ($(total) steps)" - end - printstyled(report.io, msg, '\n'; color = report.color, bold = true) - report.total = total - report.last_count = 0 - report.last_time = time_ns() +function start_progress!(report::ReportIO, msg; total_count = report.total_count) + @unpack io, print_color = report + @argcheck report.start_time_ns ≡ nothing "end_progress was not called" + totalmsg = total_count ≡ nothing ? "unknown number of" : total_count + msg *= " ($(totalmsg) steps)" + printstyled(io, msg, '\n'; color = print_color, bold = true) + report.total_count = total_count + report.start_time_ns = report.last_printed_time_ns = time_ns() + report.last_printed_count = 0 nothing end +function _report_avg_msg(report::ReportIO, count, _time_ns) + s_per_iteration = (_time_ns - report.start_time_ns) / count / 1_000_000_000 + "$(round(s_per_iteration; sigdigits = 2)) s/step" +end + """ $SIGNATURES Terminate a progress meter. """ -function end_progress!(report::ReportIO) - printstyled(report.io, " ...done\n"; bold = true, color = report.color) - report.last_count = nothing +function end_progress!(report::ReportIO, count::Integer = report.total_count::Integer) + avgmsg = _report_avg_msg(report, count, time_ns()) + printstyled(report.io, "$(avgmsg) ...done\n"; bold = true, color = report.print_color) + report.start_time_ns = nothing end """ $SIGNATURES -Display `objects` via the appropriate mechanism. - -When a single `Int` is given, it is treated as the index of the current step. +Display `report` via the appropriate mechanism. `count` is the index of the current step. """ -function report!(report::ReportIO, count::Int) - @unpack io, step_count, color, total = report - @argcheck report.last_count isa Int "start_progress! was not called." - if count % step_count == 0 +function report!(report::ReportIO, count::Integer) + @unpack io, countΔ, time_nsΔ, start_time_ns, last_printed_count, last_printed_time_ns, + total_count = report + @argcheck start_time_ns ≠ nothing "start_progress! was not called." + _time_ns = time_ns() + ispastcount = countΔ ≤ count - last_printed_count + ispasttime = time_nsΔ ≤ _time_ns - last_printed_time_ns + if ispastcount && ispasttime msg = "step $(count)" - if total isa Int - msg *= "/$(total)" + if total_count isa Int + msg *= " (of $(total_count))" end - t = time_ns() - s_per_iteration = (t - report.last_time) / step_count / 1000 - msg *= ", $(round(s_per_iteration; sigdigits = 2)) s/step" - printstyled(io, msg, '\n'; color = color) - report.last_time = t - report.last_count = count + msg *= ", " * _report_avg_msg(report, count, _time_ns) + printstyled(io, msg, '\n'; color = report.print_color) + report.last_printed_time_ns = _time_ns + report.last_printed_count = count end nothing end diff --git a/src/sampler.jl b/src/sampler.jl index 28da1ee..808b233 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -40,7 +40,7 @@ which has elements that conform to the sampler. function mcmc(sampler::NUTS{Tv,Tf}, N::Int) where {Tv,Tf} @unpack rng, H, q, ϵ, max_depth, report = sampler sample = Vector{NUTS_Transition{Tv,Tf}}(undef, N) - start_progress!(report, N, "MCMC") + start_progress!(report, "MCMC"; total_count = N) for i in 1:N trans = NUTS_transition(rng, H, q, ϵ, max_depth) q = trans.q @@ -63,7 +63,7 @@ When the last two parameters are not specified, initialize using `adapting_ϵ`. function mcmc_adapting_ϵ(sampler::NUTS{Tv,Tf}, N::Int, A_params, A) where {Tv,Tf} @unpack rng, H, q, max_depth, report = sampler sample = Vector{NUTS_Transition{Tv,Tf}}(undef, N) - start_progress!(report, N, "MCMC, adapting ϵ") + start_progress!(report, "MCMC, adapting ϵ"; total_count = N) for i in 1:N trans = NUTS_transition(rng, H, q, get_ϵ(A), max_depth) A = adapt_stepsize(A_params, A, trans.a) diff --git a/test/runtests.jl b/test/runtests.jl index 74acd94..041ac6f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -111,14 +111,23 @@ function rand_Hz(K) H, z end -include("test-Hamiltonian-leapfrog.jl") -include("test-buildingblocks.jl") -include("test-stepsize.jl") -include("test-sample-dummy.jl") -include("test-tuners.jl") -include("test-sample-normal.jl") -include("test-normal-mcmc.jl") -include("test-statistics.jl") -include("test-reporting.jl") +macro include_testset(filename) + @assert filename isa AbstractString + quote + @testset $(filename) begin + include($(filename)) + end + end +end + +@include_testset("test-Hamiltonian-leapfrog.jl") +@include_testset("test-buildingblocks.jl") +@include_testset("test-stepsize.jl") +@include_testset("test-sample-dummy.jl") +@include_testset("test-tuners.jl") +@include_testset("test-sample-normal.jl") +@include_testset("test-normal-mcmc.jl") +@include_testset("test-statistics.jl") +@include_testset("test-reporting.jl") include("../docs/make.jl") diff --git a/test/test-reporting.jl b/test/test-reporting.jl index fbb0975..dcd9c74 100644 --- a/test/test-reporting.jl +++ b/test/test-reporting.jl @@ -1,16 +1,18 @@ @testset "reporting" begin ℓ = DistributionLogDensity(MvNormal, 3) - @color_output false begin - output = @capture_err begin - sample, nuts = NUTS_init_tune_mcmc(RNG, ℓ, 1000; report = ReportIO()) + output = @color_output false begin + @capture_err begin + sample, nuts = NUTS_init_tune_mcmc(RNG, ℓ, 1000; + report = ReportIO(; countΔ = 100, time_nsΔ = -1)) end end + float_regex = raw"(0|[1-9][0-9]*)(\.[0-9]+)?([eE][+-]?[0-9]+)?" function expectedA(msg, n) - r = "$msg \\($(n) steps\\)\\n" + r = "$(msg) \\($(n) steps\\)\\n" for i in 100:100:n - r *= "step $(i)/$(n), \\d+\\.\\d+ s/step\\n" + r *= "step $(i) \\(of $(n)\\), $(float_regex) s/step\\n" end - r *= " \\.\\.\\.done\\n" + r *= "$(float_regex) s/step \\.\\.\\.done\\n" end raw_regex = join(expectedA.(vcat(fill("MCMC, adapting ϵ", 7), ["MCMC"]), [75, 25, 50, 100, 200, 400, 50, 1000]), "")