diff --git a/Project.toml b/Project.toml index 01a80cd..31ac1c3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "GenParticleFilters" uuid = "56b76ac4-72ef-411e-b419-6d312ed86a6f" authors = ["Xuan "] -version = "0.1.4" +version = "0.1.5" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/GenParticleFilters.jl b/src/GenParticleFilters.jl index 706cbb7..953d456 100644 --- a/src/GenParticleFilters.jl +++ b/src/GenParticleFilters.jl @@ -3,8 +3,9 @@ module GenParticleFilters using Gen, Distributions using Gen: ParticleFilterState -export ParticleFilterState +export ParticleFilterState, ParticleFilterSubState, ParticleFilterView +include("view.jl") include("utils.jl") include("initialize.jl") include("update.jl") diff --git a/src/rejuvenate.jl b/src/rejuvenate.jl index bf15b6b..020bd03 100644 --- a/src/rejuvenate.jl +++ b/src/rejuvenate.jl @@ -12,7 +12,7 @@ a tuple with a trace as the first return value. `method` specifies the rejuvenation method: `:move` for MCMC moves without a reweighting step, and `:reweight` for rejuvenation with a reweighting step. """ -function pf_rejuvenate!(state::ParticleFilterState, kern, kern_args::Tuple=(), +function pf_rejuvenate!(state::ParticleFilterView, kern, kern_args::Tuple=(), n_iters::Int=1; method::Symbol=:move) if method == :move return pf_move_accept!(state, kern, kern_args, n_iters) @@ -34,7 +34,7 @@ a tuple `(trace, accept)`, where `trace` is the (potentially) new trace, and can be supplied with `kern_args`. The kernel is repeatedly applied to each trace for `n_iters`. """ -function pf_move_accept!(state::ParticleFilterState, +function pf_move_accept!(state::ParticleFilterView, kern, kern_args::Tuple=(), n_iters::Int=1) # Potentially rejuvenate each trace for (i, trace) in enumerate(state.traces) @@ -44,10 +44,7 @@ function pf_move_accept!(state::ParticleFilterState, end state.new_traces[i] = trace end - # Swap references - tmp = state.traces - state.traces = state.new_traces - state.new_traces = tmp + update_refs!(state) return state end @@ -66,7 +63,7 @@ accumulated accordingly. [1] R. A. G. Marques and G. Storvik, "Particle move-reweighting strategies for online inference," Preprint series. Statistical Research Report, 2013. """ -function pf_move_reweight!(state::ParticleFilterState, +function pf_move_reweight!(state::ParticleFilterView, kern, kern_args::Tuple=(), n_iters::Int=1) # Move and reweight each trace for (i, trace) in enumerate(state.traces) @@ -79,10 +76,7 @@ function pf_move_reweight!(state::ParticleFilterState, state.new_traces[i] = trace state.log_weights[i] += weight end - # Swap references - tmp = state.traces - state.traces = state.new_traces - state.new_traces = tmp + update_refs!(state) return state end diff --git a/src/resample.jl b/src/resample.jl index 70ba79d..0243d06 100644 --- a/src/resample.jl +++ b/src/resample.jl @@ -60,10 +60,7 @@ function pf_multinomial_resample!(state::ParticleFilterState; ws = state.log_weights[state.parents] .- log_priorities[state.parents] state.log_weights = ws .+ (log(n_particles) - logsumexp(ws)) end - # Swap references - tmp = state.traces - state.traces = state.new_traces - state.new_traces = tmp + update_refs!(state) return state end @@ -115,10 +112,7 @@ function pf_residual_resample!(state::ParticleFilterState; ws = state.log_weights[state.parents] .- log_priorities[state.parents] state.log_weights = ws .+ (log(n_particles) - logsumexp(ws)) end - # Swap references - tmp = state.traces - state.traces = state.new_traces - state.new_traces = tmp + update_refs!(state) return state end @@ -170,9 +164,6 @@ function pf_stratified_resample!(state::ParticleFilterState; ws = state.log_weights[state.parents] .- log_priorities[state.parents] state.log_weights = ws .+ (log(n_particles) - logsumexp(ws)) end - # Swap references - tmp = state.traces - state.traces = state.new_traces - state.new_traces = tmp + update_refs!(state) return state end diff --git a/src/update.jl b/src/update.jl index 4d57b23..ced806f 100644 --- a/src/update.jl +++ b/src/update.jl @@ -9,7 +9,7 @@ Perform a particle filter update, where the model arguments are adjusted and new observations are conditioned upon. New latent choices are sampled from the model's default proposal. """ -function pf_update!(state::ParticleFilterState, new_args::Tuple, +function pf_update!(state::ParticleFilterView, new_args::Tuple, argdiffs::Tuple, observations::ChoiceMap) n_particles = length(state.traces) for i=1:n_particles @@ -20,10 +20,7 @@ function pf_update!(state::ParticleFilterState, new_args::Tuple, end state.log_weights[i] += increment end - # Swap references - tmp = state.traces - state.traces = state.new_traces - state.new_traces = tmp + update_refs!(state) return state end @@ -51,7 +48,7 @@ that occur in `a` also occur in `b`, and the values at those addresses are equal. It is an error if no trace `t_new` satisfying the above conditions exists in the support of the model (with the new arguments). """ -function pf_update!(state::ParticleFilterState, new_args::Tuple, +function pf_update!(state::ParticleFilterView, new_args::Tuple, argdiffs::Tuple, observations::ChoiceMap, proposal::GenerativeFunction, proposal_args::Tuple) n_particles = length(state.traces) @@ -66,10 +63,7 @@ function pf_update!(state::ParticleFilterState, new_args::Tuple, end state.log_weights[i] += up_weight - prop_weight end - # Swap references - tmp = state.traces - state.traces = state.new_traces - state.new_traces = tmp + update_refs!(state) return state end @@ -107,7 +101,7 @@ calls to `pf_update!`). Similar functionality is provided by [`move_reweight`](@ref), except that `pf_update!` also allows model arguments to be updated. """ -function pf_update!(state::ParticleFilterState, new_args::Tuple, +function pf_update!(state::ParticleFilterView, new_args::Tuple, argdiffs::Tuple, observations::ChoiceMap, fwd_proposal::GenerativeFunction, fwd_args::Tuple, bwd_proposal::GenerativeFunction, bwd_args::Tuple) @@ -122,9 +116,6 @@ function pf_update!(state::ParticleFilterState, new_args::Tuple, assess(bwd_proposal, (state.new_traces[i], bwd_args...), discard) state.log_weights[i] += up_weight - fwd_weight + bwd_weight end - # Swap references - tmp = state.traces - state.traces = state.new_traces - state.new_traces = tmp + update_refs!(state) return state end diff --git a/src/utils.jl b/src/utils.jl index bd5d6f3..cfab586 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,6 +6,17 @@ export mean, var using Gen: effective_sample_size using Statistics +@inline function update_refs!(state::ParticleFilterState) + # Swap references + tmp = state.traces + state.traces = state.new_traces + state.new_traces = tmp +end + +@inline function update_refs!(state::ParticleFilterSubState) + state.traces[:] = state.new_traces +end + lognorm(v::AbstractVector) = v .- logsumexp(v) """ @@ -14,7 +25,7 @@ lognorm(v::AbstractVector) = v .- logsumexp(v) Return the vector of normalized log weights for the current state, one for each particle. """ -get_log_norm_weights(state::ParticleFilterState) = lognorm(state.log_weights) +get_log_norm_weights(state::ParticleFilterView) = lognorm(state.log_weights) """ get_norm_weights(state::ParticleFilterState) @@ -22,14 +33,14 @@ get_log_norm_weights(state::ParticleFilterState) = lognorm(state.log_weights) Return the vector of normalized weights for the current state, one for each particle. """ -get_norm_weights(state::ParticleFilterState) = exp.(get_log_norm_weights(state)) +get_norm_weights(state::ParticleFilterView) = exp.(get_log_norm_weights(state)) """ effective_sample_size(state::ParticleFilterState) Computes the effective sample size of the particles in the filter. """ -Gen.effective_sample_size(state::ParticleFilterState) = +Gen.effective_sample_size(state::ParticleFilterView) = Gen.effective_sample_size(get_log_norm_weights(state)) """ @@ -37,7 +48,7 @@ Gen.effective_sample_size(state::ParticleFilterState) = Alias for `effective_sample_size`(@ref). Computes the effective sample size. """ -get_ess(state::ParticleFilterState) = Gen.effective_sample_size(state) +get_ess(state::ParticleFilterView) = Gen.effective_sample_size(state) """ mean(state::ParticleFilterState[, addr]) @@ -45,10 +56,10 @@ get_ess(state::ParticleFilterState) = Gen.effective_sample_size(state) Returns the weighted empirical mean for a particular trace address `addr`. If `addr` is not provided, returns the empirical mean of the return value. """ -Statistics.mean(state::ParticleFilterState, addr) = +Statistics.mean(state::ParticleFilterView, addr) = sum(get_norm_weights(state) .* getindex.(state.traces, addr)) -Statistics.mean(state::ParticleFilterState) = +Statistics.mean(state::ParticleFilterView) = sum(get_norm_weights(state) .* get_retval.(state.traces)) """ @@ -57,10 +68,10 @@ Statistics.mean(state::ParticleFilterState) = Returns the empirical variance for a particular trace address `addr`. If `addr` is not provided, returns the empirical variance of the return value. """ -Statistics.var(state::ParticleFilterState, addr) = +Statistics.var(state::ParticleFilterView, addr) = sum(get_norm_weights(state) .* (getindex.(state.traces, addr) .- mean(state, addr)).^2) -Statistics.var(state::ParticleFilterState) = +Statistics.var(state::ParticleFilterView) = sum(get_norm_weights(state) .* (get_retval.(state.traces) .- mean(state)).^2) diff --git a/src/view.jl b/src/view.jl new file mode 100644 index 0000000..e9eed55 --- /dev/null +++ b/src/view.jl @@ -0,0 +1,29 @@ +struct ParticleFilterSubState{U,I,L} + traces::SubArray{U,1,Vector{U},I,L} + new_traces::SubArray{U,1,Vector{U},I,L} + log_weights::SubArray{Float64,1,Vector{Float64},I,L} + parents::SubArray{Int,1,Vector{Int},I,L} +end + +Gen.get_traces(state::ParticleFilterSubState) = state.traces +Gen.get_log_weights(state::ParticleFilterSubState) = state.log_weights + +const ParticleFilterView{U} = + Union{ParticleFilterState{U}, ParticleFilterSubState{U}} where {U} + +function Base.view(state::ParticleFilterState{U}, + indices::AbstractVector) where {U} + L = Base.viewindexing((indices,)) == IndexLinear() + return ParticleFilterSubState{U,Tuple{typeof(indices)},L}( + view(state.traces, indices), + view(state.new_traces, indices), + view(state.log_weights, indices), + view(state.parents, indices) + ) +end + +Base.getindex(state::ParticleFilterState, indices) = + Base.view(state, indices) + +Base.firstindex(state::ParticleFilterState) = 1 +Base.lastindex(state::ParticleFilterState) = length(state.traces) diff --git a/test/runtests.jl b/test/runtests.jl index a46b43d..e708262 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -87,6 +87,18 @@ state = pf_update!(state, (10,), (UnknownChange(),), choicemap(), @test all([w != 0 for w in get_log_weights(state)]) end +@testset "Update with different proposals per view" begin +state = pf_initialize(line_model, (0,), choicemap(), 100) +substate = pf_update!(state[1:50], (10,), (UnknownChange(),), generate_line(10)) +@test all([tr[:line => 10 => :y] == 0 for tr in get_traces(substate)]) +@test all([w != 0 for w in get_log_weights(substate)]) +substate = pf_update!(state[51:end], (10,), (UnknownChange(),), + generate_line(10), outlier_propose, (10,)) +@test all([tr[:line => 10 => :y] == 0 for tr in get_traces(substate)]) +@test all([tr[:line => 10 => :outlier] == false for tr in get_traces(substate)]) +@test all([w != 0 for w in get_log_weights(state)]) +end + end @testset "Particle resampling" begin @@ -240,6 +252,38 @@ rel_weights = parse.(Float64, rel_weights) @test all(isapprox.(new_weights, old_weights .+ rel_weights; atol=1e-3)) end +@testset "Rejuvenation on separate views" begin +# Log which particles were rejuvenated +buffer = IOBuffer() +logger = SimpleLogger(buffer, Logging.Debug) +state = pf_initialize(line_model, (10,), generate_line(10, 1.), 100) +old_traces = get_traces(state)[1:50] +old_weights = get_log_weights(state)[51:end] + +with_logger(logger) do + pf_move_accept!(state[1:50], metropolis_hastings, (select(:slope),), 1) + pf_move_reweight!(state[51:end], move_reweight, (select(:slope),), 1) +end + +# Extract acceptances and relative weights from debug log +lines = split(String(take!(buffer)), "\n") +a_lines = filter(s -> occursin("Accepted: ", s), lines) +accepts = [match(r".*Accepted: (\w+).*", l).captures[1] for l in a_lines] +accepts = parse.(Bool, accepts) +r_lines = filter(s -> occursin("Rel. Weight: ", s), lines) +rel_weights = [match(r".*Rel\. Weight: (.+)\s*", l).captures[1] for l in r_lines] +rel_weights = parse.(Float64, rel_weights) + +# Check that only traces that were accepted are rejuvenated +new_traces = get_traces(state)[1:50] +@test all(a ? t1 !== t2 : t1 === t2 + for (a, t1, t2) in zip(accepts, old_traces, new_traces)) +# Check that weights are adjusted accordingly +new_weights = get_log_weights(state)[51:end] +@test all(isapprox.(new_weights, old_weights .+ rel_weights; atol=1e-3)) + +end + end @testset "Utility functions" begin