Skip to content

Commit

Permalink
Merge pull request #125 from Juice-jl/callback
Browse files Browse the repository at this point in the history
early stopping as a call back function
  • Loading branch information
MhDang authored Oct 13, 2022
2 parents fe10122 + 7c8cdb3 commit ae94c1d
Showing 1 changed file with 64 additions and 0 deletions.
64 changes: 64 additions & 0 deletions src/parameters/em.jl
Original file line number Diff line number Diff line change
Expand Up @@ -569,4 +569,68 @@ end
cleanup(caller::CALLBACK) = nothing
cleanup(caller::LikelihoodsLog) = begin
CUDA.unsafe_free!(caller.mars_mem)
end

# early stopping
mutable struct EarlyStopPC <: CALLBACK
likelihoods_log
patience
warmup
val

best_value
best_iter
best_bpc

n_increase
iter
EarlyStopPC(likelihoods_log; patience, warmup=1, val=:valid_x) = begin
@assert val == :valid_x
@assert patience % likelihoods_log.iter == 0
@assert !isnothing(likelihoods_log.valid_x)
new(likelihoods_log, Int(ceil(patience / likelihoods_log.iter)),
warmup, val, -Inf, 0, nothing, 0, 0)
end
end

init(caller::EarlyStopPC; args...) = begin
init(caller.likelihoods_log; args...)
bpc = caller.likelihoods_log.bpc
best_bpc = (edge_layers_up = deepcopy(bpc.edge_layers_up), heap = deepcopy(bpc.heap))
caller.best_bpc = best_bpc
end

call(caller::EarlyStopPC, epoch, log_likelihood) = begin
valid_ll, test_ll = call(caller.likelihoods_log, epoch, log_likelihood)
caller.iter += 1
flag = false
if isnothing(valid_ll) || caller.iter < caller.warmup
flag = false
elseif valid_ll >= caller.best_value
caller.n_increase = 0
caller.best_value = valid_ll
caller.best_iter = epoch
copy_bpc!(caller.best_bpc, caller.likelihoods_log.bpc)
flag = false
elseif valid_ll < caller.best_value
caller.n_increase += 1
if caller.n_increase > caller.patience
copy_bpc!(caller.likelihoods_log.bpc, caller.best_bpc)
flag = true
else
flag = false
end
else
error("")
end
return flag
end

copy_bpc!(dst, src) = begin
copyto!(dst.edge_layers_up.vectors, src.edge_layers_up.vectors)
copyto!(dst.heap, src.heap)
end

cleanup(caller::EarlyStopPC) = begin
cleanup(caller.likelihoods_log)
end

0 comments on commit ae94c1d

Please sign in to comment.