Skip to content

Commit

Permalink
Merge pull request #9 from gdalle/update_hmms
Browse files Browse the repository at this point in the history
Update to HiddenMarkovModels v0.6
  • Loading branch information
dmetivie authored Oct 25, 2024
2 parents e1f8faf + 4b6b1e7 commit cfa322e
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
name = "PeriodicHiddenMarkovModels"
uuid = "4873b48c-7fd1-4fb6-93c5-649b25bdde2e"
authors = ["David Métivier <[email protected]> and contributors"]
version = "0.2.0"
version = "0.3.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"

[compat]
julia = "1.10"
julia = "1.10, 1.11"
ArgCheck = "2.3"
StatsAPI = "1.6"
HiddenMarkovModels = "0.5"
HiddenMarkovModels = "0.6"
1 change: 1 addition & 0 deletions src/PeriodicHiddenMarkovModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using ArgCheck: @argcheck
import HiddenMarkovModels as HMMs
import StatsAPI
using HiddenMarkovModels
using HiddenMarkovModels: AbstractVectorOrNTuple

export
# utilities.jl
Expand Down
14 changes: 7 additions & 7 deletions src/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ function StatsAPI.fit!(
fb_storage::HMMs.ForwardBackwardStorage,
obs_seq::AbstractVector,
n2t::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
) where {T}
(; γ, ξ) = fb_storage
L, N = period(hmm), length(hmm)

hmm.init .= zero(T)
for l in 1:L
hmm.trans[l] .= zero(T)
end
for k in eachindex(seq_ends)
@views for k in eachindex(seq_ends)
t1, t2 = HMMs.seq_limits(seq_ends, k)
hmm.init .+= γ[:, t1]
n2t_k = view(n2t, t1:t2)
n2t_k = n2t[t1:t2]

for l in 1:L
hmm.trans[l] .+= sum(ξ[first(parentindices(n2t_k))[findall(n2t_k .== l)]])
Expand All @@ -27,9 +27,9 @@ function StatsAPI.fit!(
end
for l in 1:L
times_l = Int[]
for k in eachindex(seq_ends)
@views for k in eachindex(seq_ends)
t1, t2 = HMMs.seq_limits(seq_ends, k)
n2t_k = view(n2t, t1:t2)
n2t_k = n2t[t1:t2]
append!(times_l, first(parentindices(n2t_k))[findall(n2t_k .== l)])
end
for i in 1:N
Expand All @@ -40,4 +40,4 @@ function StatsAPI.fit!(
@assert HMMs.valid_hmm(hmm, l)
end
return nothing
end
end

0 comments on commit cfa322e

Please sign in to comment.