Skip to content

Commit

Permalink
More informative translator fieldname for trace transform.
Browse files Browse the repository at this point in the history
  • Loading branch information
ztangent committed Apr 2, 2021
1 parent fbf8b4b commit 3394832
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 21 deletions.
37 changes: 18 additions & 19 deletions src/translate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using Gen: run_first_pass, jacobian_correction, check_round_trip, run_transform
new_observations::ChoiceMap = EmptyChoiceMap(),
q_forward::GenerativeFunction,
q_forward_args::Tuple = (),
f::Union{TraceTransformDSLProgram,Nothing} = nothing)
transform::Union{TraceTransformDSLProgram,Nothing} = nothing)
Constructor for a extending trace translator.
Run the translator with:
(output_trace, log_weight) = translator(input_trace)
Expand All @@ -22,7 +22,7 @@ Run the translator with:
new_observations::ChoiceMap = EmptyChoiceMap()
q_forward::GenerativeFunction
q_forward_args::Tuple = ()
f::Union{TraceTransformDSLProgram,Nothing} = nothing # a bijection
transform::Union{TraceTransformDSLProgram,Nothing} = nothing # a bijection
end

function (translator::ExtendingTraceTranslator)(prev_model_trace::Trace)
Expand All @@ -33,14 +33,14 @@ function (translator::ExtendingTraceTranslator)(prev_model_trace::Trace)
forward_proposal_score = get_score(forward_proposal_trace)

# transform forward proposal
if translator.f === nothing
if translator.transform === nothing
constraints = get_choices(forward_proposal_trace)
log_abs_determinant = 0.0
else
first_pass_results =
run_first_pass(translator.f, forward_proposal_trace, nothing)
run_first_pass(translator.transform, forward_proposal_trace, nothing)
log_abs_determinant =
jacobian_correction(translator.f, forward_proposal_trace,
jacobian_correction(translator.transform, forward_proposal_trace,
nothing, first_pass_results, nothing)
constraints = first_pass_results.constraints
end
Expand Down Expand Up @@ -97,7 +97,7 @@ the observed random choices in the previous trace.
q_forward_args::Tuple = ()
q_backward::GenerativeFunction
q_backward_args::Tuple = ()
f::TraceTransformDSLProgram
transform::TraceTransformDSLProgram
end

function Gen.inverse(translator::UpdatingTraceTranslator, prev_model_trace::Trace,
Expand All @@ -106,23 +106,22 @@ function Gen.inverse(translator::UpdatingTraceTranslator, prev_model_trace::Trac
get_args(prev_model_trace), map((_)->UnknownChange(), get_args(prev_model_trace)),
prev_observations, translator.q_backward, translator.q_backward_args,
translator.q_forward, translator.q_forward_args,
inverse(translator.f))
inverse(translator.transform))
end

function Gen.run_transform(translator::UpdatingTraceTranslator,
prev_model_trace::Trace, forward_proposal_trace::Trace,
check::Bool=false)
@unpack f, new_observations = translator
prev_model_trace::Trace, forward_proposal_trace::Trace)
@unpack transform, new_observations = translator
@unpack p_new_args, p_argdiffs, q_backward, q_backward_args = translator
first_pass_results =
Gen.run_first_pass(f, prev_model_trace, forward_proposal_trace)
first_pass_results = run_first_pass(
transform, prev_model_trace, forward_proposal_trace)
constraints = merge(first_pass_results.constraints, new_observations)
(new_model_trace, _, _, discard) = update(
new_model_trace, _, _, discard = update(
prev_model_trace, p_new_args, p_argdiffs, constraints)
log_abs_determinant = jacobian_correction(f, prev_model_trace,
forward_proposal_trace, first_pass_results, discard)
backward_proposal_trace, = generate(q_backward,
(new_model_trace, q_backward_args...), first_pass_results.u_back)
log_abs_determinant = jacobian_correction(
transform, prev_model_trace, forward_proposal_trace, first_pass_results, discard)
backward_proposal_trace, _ = generate(
q_backward, (new_model_trace, q_backward_args...), first_pass_results.u_back)
return (new_model_trace, backward_proposal_trace, log_abs_determinant)
end

Expand All @@ -135,7 +134,7 @@ function (translator::UpdatingTraceTranslator)(

# apply trace transform
(new_model_trace, backward_proposal_trace, log_abs_determinant) =
run_transform(translator, prev_model_trace, forward_proposal_trace, check)
run_transform(translator, prev_model_trace, forward_proposal_trace)

# compute log weight
prev_model_score = get_score(prev_model_trace)
Expand All @@ -149,7 +148,7 @@ function (translator::UpdatingTraceTranslator)(
inverter = inverse(translator, prev_model_trace, prev_observations)
argdiffs = map((_) -> UnknownChange(), get_args(prev_model_trace))
(prev_model_trace_rt, forward_proposal_trace_rt, _) =
run_transform(inverter, new_model_trace, backward_proposal_trace, check)
run_transform(inverter, new_model_trace, backward_proposal_trace)
check_round_trip(prev_model_trace, prev_model_trace_rt,
forward_proposal_trace, forward_proposal_trace_rt)
end
Expand Down
4 changes: 2 additions & 2 deletions src/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ function pf_update!(state::ParticleFilterView, new_args::Tuple,
n_particles = length(state.traces)
translator = GenParticleFilters.ExtendingTraceTranslator(
p_new_args=new_args, p_argdiffs=argdiffs, new_observations=observations,
q_forward=proposal, q_forward_args=proposal_args, f=transform)
q_forward=proposal, q_forward_args=proposal_args, transform=transform)
return pf_update!(state, translator)
end

Expand Down Expand Up @@ -169,6 +169,6 @@ function pf_update!(state::ParticleFilterView, new_args::Tuple,
translator = GenParticleFilters.UpdatingTraceTranslator(
p_new_args=new_args, p_argdiffs=argdiffs, new_observations=observations,
q_forward=fwd_proposal, q_forward_args=fwd_args,
q_backward=bwd_proposal, q_backward_args=bwd_args, f=transform)
q_backward=bwd_proposal, q_backward_args=bwd_args, transform=transform)
return pf_update!(state, translator; check=check)
end

2 comments on commit 3394832

@ztangent
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/33451

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.6 -m "<description of version>" 3394832987e43d4214bddf49fcdccbf4098c7c18
git push origin v0.1.6

Please sign in to comment.