diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index 1bc2566..816af8a 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -1,13 +1,19 @@ function baum_welch_has_converged( - logL_evolution::Vector; atol::Real, loglikelihood_increasing::Bool + logL_evolution::Vector; atol::Real, loglikelihood_increasing::Bool, stopping_criterion::Symbol ) if length(logL_evolution) >= 2 logL, logL_prev = logL_evolution[end], logL_evolution[end - 1] progress = logL - logL_prev - if loglikelihood_increasing && progress < min(0, -atol) - error("Loglikelihood decreased in Baum-Welch") - elseif progress < atol - return true + if stopping_criterion == :convergence + if loglikelihood_increasing && progress < min(0, -atol) + error("Loglikelihood decreased in Baum-Welch") + elseif progress < atol + return true + end + elseif stopping_criterion == :stability + if abs(progress) < atol + return true + end end end return false @@ -26,12 +32,13 @@ function baum_welch!( atol::Real, max_iterations::Integer, loglikelihood_increasing::Bool, + stopping_criterion::Symbol ) for _ in 1:max_iterations forward_backward!(fb_storage, hmm, obs_seq, control_seq; seq_ends) push!(logL_evolution, logdensityof(hmm) + sum(fb_storage.logL)) fit!(hmm, fb_storage, obs_seq, control_seq; seq_ends) - if baum_welch_has_converged(logL_evolution; atol, loglikelihood_increasing) + if baum_welch_has_converged(logL_evolution; atol, loglikelihood_increasing, stopping_criterion) break end end @@ -50,6 +57,7 @@ Return a tuple `(hmm_est, loglikelihood_evolution)` where `hmm_est` is the estim - `atol`: minimum loglikelihood increase at an iteration of the algorithm (otherwise the algorithm is deemed to have converged) - `max_iterations`: maximum number of iterations of the algorithm - `loglikelihood_increasing`: whether to throw an error if the loglikelihood decreases +- `stopping_criterion`: The stopping criterion (either `:convergence` or `:stability`). """ function baum_welch( hmm_guess::AbstractHMM, @@ -59,6 +67,7 @@ function baum_welch( atol=1e-5, max_iterations=100, loglikelihood_increasing=true, + stopping_criterion::Symbol=:convergence ) hmm = deepcopy(hmm_guess) fb_storage = initialize_forward_backward(hmm, obs_seq, control_seq; seq_ends) @@ -74,6 +83,7 @@ function baum_welch( atol, max_iterations, loglikelihood_increasing=false, + stopping_criterion=stopping_criterion ) return hmm, logL_evolution end diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index 6f3f8e5..bb05ca9 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -128,6 +128,8 @@ function obs_logdensities!( @inbounds @simd for i in eachindex(logb, dists) logb[i] = logdensityof(dists[i], obs) end + logb[findall(i -> i < -log(-nextfloat(-Inf)), logb)] .= -log(-nextfloat(-Inf)) + logb[findall(i -> i > log(prevfloat(Inf)), logb)] .= log(prevfloat(Inf)) @argcheck maximum(logb) < typemax(T) return nothing end