Skip to content

Commit

Permalink
Merge pull request #292 from MineralsCloud:isconvergent
Browse files Browse the repository at this point in the history
Fix `isconvergent` & `TestConvergence` using `StaticConfig.threshold`
  • Loading branch information
singularitti authored Oct 15, 2023
2 parents 885673d + 8fb5e56 commit dafdad7
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
11 changes: 8 additions & 3 deletions src/ConvergenceTestWorkflow/Config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ end
recipe::String
template::String
with::Union{CutoffEnergies,MonkhorstPackGrids}
criteria::Quantity
threshold::Quantity
io::IO = IO()
data::Data = Data()
cli::SoftwareConfig
function StaticConfig(recipe, template, with, criteria, io, data, cli)
function StaticConfig(recipe, template, with, threshold, io, data, cli)
@assert recipe in ("ecut", "kmesh")
if !isfile(template)
@warn "I cannot find template file `$template`!"
end
return new(recipe, template, with, criteria, io, data, cli)
return new(recipe, template, with, threshold, io, data, cli)
end
end

Expand Down Expand Up @@ -80,6 +80,10 @@ function _update!(conf::Conf, data::Data)
conf.data.raw = abspath(expanduser(data.raw))
return conf
end
function _update!(conf::Conf, threshold::Quantity)
conf.threshold = threshold
return conf
end

function expand(config::StaticConfig, calculation::Calculation)
conf = Conf()
Expand All @@ -89,6 +93,7 @@ function expand(config::StaticConfig, calculation::Calculation)
_update!(conf, config.with)
_update!(conf, config.io, config.with)
_update!(conf, config.data)
_update!(conf, config.threshold)
return conf
end

Expand Down
11 changes: 6 additions & 5 deletions src/ConvergenceTestWorkflow/actions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ end
struct TestConvergence{T} <: Action{T}
calculation::T
end
(::TestConvergence)(data) = isconvergent(data)
(::TestConvergence)(threshold) = Base.Fix2(isconvergent, threshold)

function isconvergent(a)
terms = abs.(diff(collect(a)))
x, y, z = last(terms, 3)
return all(0 <= r < 1 for r in (y / x, z / y))
function isconvergent(iter, threshold)
last3 = last(sort(iter; by=first), 3) # Sort a `Set` of `Pair`s by the keys, i.e., increasing cutoff energies
min, max = extrema(last.(last3)) # Get the minimum and maximum energies from the last 3 pairs
range = abs(max - min)
return zero(range) <= range <= threshold
end
2 changes: 1 addition & 1 deletion src/ConvergenceTestWorkflow/think.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ think(action::CreateInput, conf::Conf) =
think(action::ExtractData, conf::Conf) =
collect(Thunk(action, file) for file in last.(conf.io))
think(action::SaveData, conf::Conf) = Thunk(action(conf.data.raw))
think(action::TestConvergence, ::Conf) = Thunk(action)
think(action::TestConvergence, conf::Conf) = Thunk(action, conf.threshold)
function think(action::Action{T}, config::StaticConfig) where {T}
config = expand(config, T())
return think(action, config::Conf)
Expand Down

0 comments on commit dafdad7

Please sign in to comment.