Skip to content

Commit

Permalink
small updates
Browse files Browse the repository at this point in the history
  • Loading branch information
topipa committed Sep 10, 2023
1 parent 35fee1f commit a72ccad
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 64 deletions.
125 changes: 65 additions & 60 deletions R/brmsfit_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,31 +43,37 @@ moment_match.brmsfit <- function(x,

# ensure draws are in matrix form
draws <- posterior::as_draws_matrix(x)
# draws <- as.matrix(draws)

# if (!is.null(target_observation_weights)) {
# out <- tryCatch(posterior::subset_draws(draws, variable = "log_lik"),
# error = function(cond) {
# message(cond)
# message("\nYour stan fit does not include a parameter called log_lik.")
# message("To use target_observation_weights, you must define log_lik in the generated quantities block.")
# return(NA)
# }
# )
#
# log_ratio_fun <- function(draws, fit, ...) {
# cdraws <- constrain_draws(fit, draws)
# ll <- posterior::merge_chains(
# posterior::subset_draws(cdraws, variable = "log_lik")
# )
# colSums(t(drop(ll)) * (target_observation_weights - 1))
# }
# }
if (!is.null(target_observation_weights)) {
out <- tryCatch(log_lik(x),
error = function(cond) {
message(cond)
message("\nYour brmsfit does not include a parameter called log_lik.")
message("This should not happen. Perhaps you are using an unsupported observation model?")
return(NA)
}
)

function(draws, fit, extra_data, ...) {
fit <- brms:::.update_pars(x = fit, upars = draws)
ll <- log_lik(fit, newdata = extra_data)
rowSums(ll)
}

log_ratio_fun <- function(draws, fit, ...) {
fit <- brms:::.update_pars(x = fit, upars = draws)
ll <- log_lik(fit)
colSums(t(drop(ll)) * (target_observation_weights - 1))
}
}


# transform the draws to unconstrained space
udraws <- unconstrain_draws.brmsfit(x, draws = draws, ...)

out <- moment_match.matrix(
# as.matrix(udraws),
udraws,
log_prob_prop_fun = log_prob_draws.brmsfit,
log_prob_target_fun = log_prob_target_fun,
Expand All @@ -86,7 +92,7 @@ moment_match.brmsfit <- function(x,
out$draws <- constrain_draws.stanfit(x$fit, out$draws, ...)
}

list(iwmm_object = out,
list(adapted_importance_sampling = out,
brmsfit_object = x)
}

Expand All @@ -107,47 +113,46 @@ constrain_draws.brmsfit <- function(x, udraws, ...) {
}

# # transform parameters to the constraint space
# update_pars_brmsfit <- function(x, draws, ...) {
# # list with one element per posterior draw
# pars <- apply(draws, 1, constrain_draws.brmsfit, x = x)
# # select required parameters only
# pars <- lapply(pars, "[", x$fit@sim$pars_oi_old)
# # transform draws
# ndraws <- length(pars)
# pars <- unlist(pars)
# npars <- length(pars) / ndraws
# dim(pars) <- c(npars, ndraws)
# # add dummy 'lp__' draws
# pars <- rbind(pars, rep(0, ndraws))
# # bring draws into the right structure
# new_draws <- named_list(x$fit@sim$fnames_oi_old, list(numeric(ndraws)))
# if (length(new_draws) != nrow(pars)) {
# stop2("Updating parameters in `update_pars_brmsfit' failed. ",
# "Please report a bug at https://github.com/paul-buerkner/brms.")
# }
# for (i in seq_len(npars)) {
# new_draws[[i]] <- pars[i, ]
# }
# # create new sim object to overwrite x$fit@sim
# x$fit@sim <- list(
# samples = list(new_draws),
# iter = ndraws,
# thin = 1,
# warmup = 0,
# chains = 1,
# n_save = ndraws,
# warmup2 = 0,
# permutation = list(seq_len(ndraws)),
# pars_oi = x$fit@sim$pars_oi_old,
# dims_oi = x$fit@sim$dims_oi_old,
# fnames_oi = x$fit@sim$fnames_oi_old,
# n_flatnames = length(x$fit@sim$fnames_oi_old)
# )
# x$fit@stan_args <- list(
# list(chain_id = 1, iter = ndraws, thin = 1, warmup = 0)
# )
# brms::rename_pars(x)
# }
update_pars_brmsfit <- function(x, draws, ...) {
# list with one element per posterior draw
pars <- apply(draws, 1, constrain_draws.brmsfit, x = x)
# select required parameters only
pars <- lapply(pars, "[", x$fit@sim$pars_oi_old)
# transform draws
ndraws <- length(pars)
pars <- unlist(pars)
npars <- length(pars) / ndraws
dim(pars) <- c(npars, ndraws)
# add dummy 'lp__' draws
pars <- rbind(pars, rep(0, ndraws))
# bring draws into the right structure
new_draws <- named_list(x$fit@sim$fnames_oi_old, list(numeric(ndraws)))
if (length(new_draws) != nrow(pars)) {
stop2("Updating parameters in `update_pars_brmsfit' failed.")
}
for (i in seq_len(npars)) {
new_draws[[i]] <- pars[i, ]
}
# create new sim object to overwrite x$fit@sim
x$fit@sim <- list(
samples = list(new_draws),
iter = ndraws,
thin = 1,
warmup = 0,
chains = 1,
n_save = ndraws,
warmup2 = 0,
permutation = list(seq_len(ndraws)),
pars_oi = x$fit@sim$pars_oi_old,
dims_oi = x$fit@sim$dims_oi_old,
fnames_oi = x$fit@sim$fnames_oi_old,
n_flatnames = length(x$fit@sim$fnames_oi_old)
)
x$fit@stan_args <- list(
list(chain_id = 1, iter = ndraws, thin = 1, warmup = 0)
)
brms::rename_pars(x)
}

# update .MISC environment of the stanfit object
# allows to call log_prob and other C++ using methods
Expand Down
7 changes: 3 additions & 4 deletions tests/testthat/test-moment-match-brmsfit-analytical.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ if (brms_available) {



test_that("moment_match.stanfit matches analytical results", {
test_that("moment_match.brmsfit matches analytical results", {
# TODO: implement this test with expectation_fun

joint_log_lik_extra_data <- function(draws, fit, extra_data, ...) {
Expand All @@ -137,13 +137,12 @@ if (brms_available) {
k_threshold = -Inf # ensure moment-matching is used
)


draws_mm_single_obs <- posterior::subset_draws(
posterior::as_draws_matrix(iw_single_obs$draws),
posterior::as_draws_matrix(iw_single_obs$adapted_importance_sampling$draws),
variable = c("b_Intercept", "sigma_sq")
)

weights_mm_single_obs <- exp(iw_single_obs$log_weights)
weights_mm_single_obs <- exp(iw_single_obs$adapted_importance_sampling$log_weights)
mean_mm_single_obs <- matrixStats::colWeightedMeans(
draws_mm_single_obs,
w = weights_mm_single_obs
Expand Down

0 comments on commit a72ccad

Please sign in to comment.