Skip to content

Commit

Permalink
Support updates and rejuvenation for sub-views of particle filters.
Browse files Browse the repository at this point in the history
  • Loading branch information
ztangent committed Mar 22, 2021
1 parent 0269780 commit 7aa1fdf
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 48 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GenParticleFilters"
uuid = "56b76ac4-72ef-411e-b419-6d312ed86a6f"
authors = ["Xuan <[email protected]>"]
version = "0.1.4"
version = "0.1.5"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
3 changes: 2 additions & 1 deletion src/GenParticleFilters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ module GenParticleFilters
using Gen, Distributions
using Gen: ParticleFilterState

export ParticleFilterState
export ParticleFilterState, ParticleFilterSubState, ParticleFilterView

include("view.jl")
include("utils.jl")
include("initialize.jl")
include("update.jl")
Expand Down
16 changes: 5 additions & 11 deletions src/rejuvenate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ a tuple with a trace as the first return value. `method` specifies the
rejuvenation method: `:move` for MCMC moves without a reweighting step,
and `:reweight` for rejuvenation with a reweighting step.
"""
function pf_rejuvenate!(state::ParticleFilterState, kern, kern_args::Tuple=(),
function pf_rejuvenate!(state::ParticleFilterView, kern, kern_args::Tuple=(),
n_iters::Int=1; method::Symbol=:move)
if method == :move
return pf_move_accept!(state, kern, kern_args, n_iters)
Expand All @@ -34,7 +34,7 @@ a tuple `(trace, accept)`, where `trace` is the (potentially) new trace, and
can be supplied with `kern_args`. The kernel is repeatedly applied to each trace
for `n_iters`.
"""
function pf_move_accept!(state::ParticleFilterState,
function pf_move_accept!(state::ParticleFilterView,
kern, kern_args::Tuple=(), n_iters::Int=1)
# Potentially rejuvenate each trace
for (i, trace) in enumerate(state.traces)
Expand All @@ -44,10 +44,7 @@ function pf_move_accept!(state::ParticleFilterState,
end
state.new_traces[i] = trace
end
# Swap references
tmp = state.traces
state.traces = state.new_traces
state.new_traces = tmp
update_refs!(state)
return state
end

Expand All @@ -66,7 +63,7 @@ accumulated accordingly.
[1] R. A. G. Marques and G. Storvik, "Particle move-reweighting strategies for
online inference," Preprint series. Statistical Research Report, 2013.
"""
function pf_move_reweight!(state::ParticleFilterState,
function pf_move_reweight!(state::ParticleFilterView,
kern, kern_args::Tuple=(), n_iters::Int=1)
# Move and reweight each trace
for (i, trace) in enumerate(state.traces)
Expand All @@ -79,10 +76,7 @@ function pf_move_reweight!(state::ParticleFilterState,
state.new_traces[i] = trace
state.log_weights[i] += weight
end
# Swap references
tmp = state.traces
state.traces = state.new_traces
state.new_traces = tmp
update_refs!(state)
return state
end

Expand Down
15 changes: 3 additions & 12 deletions src/resample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,7 @@ function pf_multinomial_resample!(state::ParticleFilterState;
ws = state.log_weights[state.parents] .- log_priorities[state.parents]
state.log_weights = ws .+ (log(n_particles) - logsumexp(ws))
end
# Swap references
tmp = state.traces
state.traces = state.new_traces
state.new_traces = tmp
update_refs!(state)
return state
end

Expand Down Expand Up @@ -115,10 +112,7 @@ function pf_residual_resample!(state::ParticleFilterState;
ws = state.log_weights[state.parents] .- log_priorities[state.parents]
state.log_weights = ws .+ (log(n_particles) - logsumexp(ws))
end
# Swap references
tmp = state.traces
state.traces = state.new_traces
state.new_traces = tmp
update_refs!(state)
return state
end

Expand Down Expand Up @@ -170,9 +164,6 @@ function pf_stratified_resample!(state::ParticleFilterState;
ws = state.log_weights[state.parents] .- log_priorities[state.parents]
state.log_weights = ws .+ (log(n_particles) - logsumexp(ws))
end
# Swap references
tmp = state.traces
state.traces = state.new_traces
state.new_traces = tmp
update_refs!(state)
return state
end
21 changes: 6 additions & 15 deletions src/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Perform a particle filter update, where the model arguments are adjusted and
new observations are conditioned upon. New latent choices are sampled from
the model's default proposal.
"""
function pf_update!(state::ParticleFilterState, new_args::Tuple,
function pf_update!(state::ParticleFilterView, new_args::Tuple,
argdiffs::Tuple, observations::ChoiceMap)
n_particles = length(state.traces)
for i=1:n_particles
Expand All @@ -20,10 +20,7 @@ function pf_update!(state::ParticleFilterState, new_args::Tuple,
end
state.log_weights[i] += increment
end
# Swap references
tmp = state.traces
state.traces = state.new_traces
state.new_traces = tmp
update_refs!(state)
return state
end

Expand Down Expand Up @@ -51,7 +48,7 @@ that occur in `a` also occur in `b`, and the values at those addresses are
equal. It is an error if no trace `t_new` satisfying the above conditions
exists in the support of the model (with the new arguments).
"""
function pf_update!(state::ParticleFilterState, new_args::Tuple,
function pf_update!(state::ParticleFilterView, new_args::Tuple,
argdiffs::Tuple, observations::ChoiceMap,
proposal::GenerativeFunction, proposal_args::Tuple)
n_particles = length(state.traces)
Expand All @@ -66,10 +63,7 @@ function pf_update!(state::ParticleFilterState, new_args::Tuple,
end
state.log_weights[i] += up_weight - prop_weight
end
# Swap references
tmp = state.traces
state.traces = state.new_traces
state.new_traces = tmp
update_refs!(state)
return state
end

Expand Down Expand Up @@ -107,7 +101,7 @@ calls to `pf_update!`).
Similar functionality is provided by [`move_reweight`](@ref), except that
`pf_update!` also allows model arguments to be updated.
"""
function pf_update!(state::ParticleFilterState, new_args::Tuple,
function pf_update!(state::ParticleFilterView, new_args::Tuple,
argdiffs::Tuple, observations::ChoiceMap,
fwd_proposal::GenerativeFunction, fwd_args::Tuple,
bwd_proposal::GenerativeFunction, bwd_args::Tuple)
Expand All @@ -122,9 +116,6 @@ function pf_update!(state::ParticleFilterState, new_args::Tuple,
assess(bwd_proposal, (state.new_traces[i], bwd_args...), discard)
state.log_weights[i] += up_weight - fwd_weight + bwd_weight
end
# Swap references
tmp = state.traces
state.traces = state.new_traces
state.new_traces = tmp
update_refs!(state)
return state
end
27 changes: 19 additions & 8 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@ export mean, var
using Gen: effective_sample_size
using Statistics

@inline function update_refs!(state::ParticleFilterState)
# Swap references
tmp = state.traces
state.traces = state.new_traces
state.new_traces = tmp
end

@inline function update_refs!(state::ParticleFilterSubState)
state.traces[:] = state.new_traces
end

lognorm(v::AbstractVector) = v .- logsumexp(v)

"""
Expand All @@ -14,41 +25,41 @@ lognorm(v::AbstractVector) = v .- logsumexp(v)
Return the vector of normalized log weights for the current state,
one for each particle.
"""
get_log_norm_weights(state::ParticleFilterState) = lognorm(state.log_weights)
get_log_norm_weights(state::ParticleFilterView) = lognorm(state.log_weights)

"""
get_norm_weights(state::ParticleFilterState)
Return the vector of normalized weights for the current state,
one for each particle.
"""
get_norm_weights(state::ParticleFilterState) = exp.(get_log_norm_weights(state))
get_norm_weights(state::ParticleFilterView) = exp.(get_log_norm_weights(state))

"""
effective_sample_size(state::ParticleFilterState)
Computes the effective sample size of the particles in the filter.
"""
Gen.effective_sample_size(state::ParticleFilterState) =
Gen.effective_sample_size(state::ParticleFilterView) =
Gen.effective_sample_size(get_log_norm_weights(state))

"""
get_ess(state::ParticleFilterState)
Alias for `effective_sample_size`(@ref). Computes the effective sample size.
"""
get_ess(state::ParticleFilterState) = Gen.effective_sample_size(state)
get_ess(state::ParticleFilterView) = Gen.effective_sample_size(state)

"""
mean(state::ParticleFilterState[, addr])
Returns the weighted empirical mean for a particular trace address `addr`.
If `addr` is not provided, returns the empirical mean of the return value.
"""
Statistics.mean(state::ParticleFilterState, addr) =
Statistics.mean(state::ParticleFilterView, addr) =
sum(get_norm_weights(state) .* getindex.(state.traces, addr))

Statistics.mean(state::ParticleFilterState) =
Statistics.mean(state::ParticleFilterView) =
sum(get_norm_weights(state) .* get_retval.(state.traces))

"""
Expand All @@ -57,10 +68,10 @@ Statistics.mean(state::ParticleFilterState) =
Returns the empirical variance for a particular trace address `addr`.
If `addr` is not provided, returns the empirical variance of the return value.
"""
Statistics.var(state::ParticleFilterState, addr) =
Statistics.var(state::ParticleFilterView, addr) =
sum(get_norm_weights(state) .*
(getindex.(state.traces, addr) .- mean(state, addr)).^2)

Statistics.var(state::ParticleFilterState) =
Statistics.var(state::ParticleFilterView) =
sum(get_norm_weights(state) .*
(get_retval.(state.traces) .- mean(state)).^2)
29 changes: 29 additions & 0 deletions src/view.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
struct ParticleFilterSubState{U,I,L}
traces::SubArray{U,1,Vector{U},I,L}
new_traces::SubArray{U,1,Vector{U},I,L}
log_weights::SubArray{Float64,1,Vector{Float64},I,L}
parents::SubArray{Int,1,Vector{Int},I,L}
end

Gen.get_traces(state::ParticleFilterSubState) = state.traces
Gen.get_log_weights(state::ParticleFilterSubState) = state.log_weights

const ParticleFilterView{U} =
Union{ParticleFilterState{U}, ParticleFilterSubState{U}} where {U}

function Base.view(state::ParticleFilterState{U},
indices::AbstractVector) where {U}
L = Base.viewindexing((indices,)) == IndexLinear()
return ParticleFilterSubState{U,Tuple{typeof(indices)},L}(
view(state.traces, indices),
view(state.new_traces, indices),
view(state.log_weights, indices),
view(state.parents, indices)
)
end

Base.getindex(state::ParticleFilterState, indices) =
Base.view(state, indices)

Base.firstindex(state::ParticleFilterState) = 1
Base.lastindex(state::ParticleFilterState) = length(state.traces)
44 changes: 44 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ state = pf_update!(state, (10,), (UnknownChange(),), choicemap(),
@test all([w != 0 for w in get_log_weights(state)])
end

@testset "Update with different proposals per view" begin
state = pf_initialize(line_model, (0,), choicemap(), 100)
substate = pf_update!(state[1:50], (10,), (UnknownChange(),), generate_line(10))
@test all([tr[:line => 10 => :y] == 0 for tr in get_traces(substate)])
@test all([w != 0 for w in get_log_weights(substate)])
substate = pf_update!(state[51:end], (10,), (UnknownChange(),),
generate_line(10), outlier_propose, (10,))
@test all([tr[:line => 10 => :y] == 0 for tr in get_traces(substate)])
@test all([tr[:line => 10 => :outlier] == false for tr in get_traces(substate)])
@test all([w != 0 for w in get_log_weights(state)])
end

end

@testset "Particle resampling" begin
Expand Down Expand Up @@ -240,6 +252,38 @@ rel_weights = parse.(Float64, rel_weights)
@test all(isapprox.(new_weights, old_weights .+ rel_weights; atol=1e-3))
end

@testset "Rejuvenation on separate views" begin
# Log which particles were rejuvenated
buffer = IOBuffer()
logger = SimpleLogger(buffer, Logging.Debug)
state = pf_initialize(line_model, (10,), generate_line(10, 1.), 100)
old_traces = get_traces(state)[1:50]
old_weights = get_log_weights(state)[51:end]

with_logger(logger) do
pf_move_accept!(state[1:50], metropolis_hastings, (select(:slope),), 1)
pf_move_reweight!(state[51:end], move_reweight, (select(:slope),), 1)
end

# Extract acceptances and relative weights from debug log
lines = split(String(take!(buffer)), "\n")
a_lines = filter(s -> occursin("Accepted: ", s), lines)
accepts = [match(r".*Accepted: (\w+).*", l).captures[1] for l in a_lines]
accepts = parse.(Bool, accepts)
r_lines = filter(s -> occursin("Rel. Weight: ", s), lines)
rel_weights = [match(r".*Rel\. Weight: (.+)\s*", l).captures[1] for l in r_lines]
rel_weights = parse.(Float64, rel_weights)

# Check that only traces that were accepted are rejuvenated
new_traces = get_traces(state)[1:50]
@test all(a ? t1 !== t2 : t1 === t2
for (a, t1, t2) in zip(accepts, old_traces, new_traces))
# Check that weights are adjusted accordingly
new_weights = get_log_weights(state)[51:end]
@test all(isapprox.(new_weights, old_weights .+ rel_weights; atol=1e-3))

end

end

@testset "Utility functions" begin
Expand Down

2 comments on commit 7aa1fdf

@ztangent
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/32560

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.5 -m "<description of version>" 7aa1fdf049711e521d7be0c5e30b98b3f1816b60
git push origin v0.1.5

Please sign in to comment.