Skip to content

Commit

Permalink
Merge branch 'brmsfit-method' into posterior-updates
Browse files Browse the repository at this point in the history
  • Loading branch information
n-kall committed Jan 18, 2024
2 parents 70a1936 + 7f7ce90 commit f3d4105
Show file tree
Hide file tree
Showing 6 changed files with 413 additions and 1 deletion.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Suggests:
rstan,
testthat
Enhances:
brms,
cmdstanr
VignetteBuilder: knitr
Config/testthat/parallel: true
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

S3method(constrain_draws,stanfit)
S3method(moment_match,CmdStanFit)
S3method(moment_match,brmsfit)
S3method(moment_match,draws_array)
S3method(moment_match,draws_df)
S3method(moment_match,draws_list)
Expand Down
176 changes: 176 additions & 0 deletions R/brmsfit_functions.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
#' Generic importance weighted moment matching algorithm for `brmsfit` objects.
#' See additional arguments from `moment_match.matrix`
#'
#' @param x A fitted `brmsfit` object.
#' @param log_prob_target_fun Log density of the target. The function
#' takes argument `draws`, which are the unconstrained draws.
#' Can also take the argument `fit` which is the stan model fit.
#' @param log_ratio_fun Log of the density ratio (target/proposal).
#' The function takes argument `draws`, which are the unconstrained
#' draws. Can also take the argument `fit` which is the stan model fit.
#' @param target_observation_weights A vector of weights for observations for
#' defining the target distribution. A value 0 means dropping the observation,
#' a value 1 means including the observation similarly as in the current data,
#' and a value 2 means including the observation twice.
#' @param expectation_fun Optional argument, NULL by default. A
#' function whose expectation is being computed. The function takes
#' arguments `draws`.
#' @param log_expectation_fun Logical indicating whether the
#' expectation_fun returns its values as logarithms or not. Defaults
#' to FALSE. If set to TRUE, the expectation function must be
#' nonnegative (before taking the logarithm). Ignored if
#' `expectation_fun` is NULL.
#' @param constrain Logical specifying whether to return draws on the
#' constrained space? Default is TRUE.
#' @param ... Further arguments passed to `moment_match.matrix`.
#'
#' @return Returns a list with 3 elements: transformed draws, updated
#' importance weights, and the pareto k diagnostic value. If expectation_fun
#' is given, also returns the expectation.
#'
#' @export
moment_match.brmsfit <- function(x,
log_prob_target_fun = NULL,
log_ratio_fun = NULL,
target_observation_weights = NULL,
expectation_fun = NULL,
log_expectation_fun = FALSE,
constrain = TRUE,
...) {
if (!is.null(target_observation_weights) && (!is.null(log_prob_target_fun) || !is.null(log_ratio_fun))) {
stop("You must give only one of target_observation_weights, log_prob_target_fun, or log_ratio_fun.")
}

# ensure draws are in matrix form
draws <- posterior::as_draws_matrix(x)
# draws <- as.matrix(draws)

if (!is.null(target_observation_weights)) {
out <- tryCatch(log_lik(x),

Check warning on line 49 in R/brmsfit_functions.R

View workflow job for this annotation

GitHub Actions / lint

file=R/brmsfit_functions.R,line=49,col=21,[object_usage_linter] no visible global function definition for 'log_lik'
error = function(cond) {

Check warning on line 50 in R/brmsfit_functions.R

View workflow job for this annotation

GitHub Actions / lint

file=R/brmsfit_functions.R,line=50,col=20,[indentation_linter] Indentation should be 6 spaces but is 20 spaces.
message(cond)
message("\nYour brmsfit does not include a parameter called log_lik.")
message("This should not happen. Perhaps you are using an unsupported observation model?")
return(NA)
}
)

function(draws, fit, extra_data, ...) {
fit <- brms:::.update_pars(x = fit, upars = draws)
ll <- log_lik(fit, newdata = extra_data)

Check warning on line 60 in R/brmsfit_functions.R

View workflow job for this annotation

GitHub Actions / lint

file=R/brmsfit_functions.R,line=60,col=13,[object_usage_linter] no visible global function definition for 'log_lik'
rowSums(ll)
}

log_ratio_fun <- function(draws, fit, ...) {
fit <- brms:::.update_pars(x = fit, upars = draws)
ll <- log_lik(fit)

Check warning on line 66 in R/brmsfit_functions.R

View workflow job for this annotation

GitHub Actions / lint

file=R/brmsfit_functions.R,line=66,col=13,[object_usage_linter] no visible global function definition for 'log_lik'
colSums(t(drop(ll)) * (target_observation_weights - 1))
}
}


# transform the draws to unconstrained space
udraws <- unconstrain_draws.brmsfit(x, draws = draws, ...)

out <- moment_match.matrix(
# as.matrix(udraws),
udraws,
log_prob_prop_fun = log_prob_draws.brmsfit,
log_prob_target_fun = log_prob_target_fun,
log_ratio_fun = log_ratio_fun,
expectation_fun = expectation_fun,
log_expectation_fun = log_expectation_fun,
fit = x,
...
)

# TODO: this does not work for some reason
# x <- brms:::.update_pars(x = x, upars = out$draws)
# x <- update_pars_brmsfit(x = x, draws = out$draws)

if (constrain) {
out$draws <- constrain_draws.stanfit(x$fit, out$draws, ...)
}

list(adapted_importance_sampling = out,
brmsfit_object = x)
}


log_prob_draws.brmsfit <- function(fit, draws, ...) {
# x <- update_misc_env(x, only_windows = TRUE)
log_prob_draws.stanfit(fit$fit, draws = draws, ...)
}

unconstrain_draws.brmsfit <- function(x, draws, ...) {
unconstrain_draws.stanfit(x$fit, draws = draws, ...)
}

constrain_draws.brmsfit <- function(x, udraws, ...) {
out <- rstan::constrain_pars(udraws, object = x$fit)
out[x$exclude] <- NULL
out
}

# # transform parameters to the constraint space
update_pars_brmsfit <- function(x, draws, ...) {
# list with one element per posterior draw
pars <- apply(draws, 1, constrain_draws.brmsfit, x = x)
# select required parameters only
pars <- lapply(pars, "[", x$fit@sim$pars_oi_old)
# transform draws
ndraws <- length(pars)
pars <- unlist(pars)
npars <- length(pars) / ndraws
dim(pars) <- c(npars, ndraws)
# add dummy 'lp__' draws
pars <- rbind(pars, rep(0, ndraws))
# bring draws into the right structure
new_draws <- named_list(x$fit@sim$fnames_oi_old, list(numeric(ndraws)))
if (length(new_draws) != nrow(pars)) {
stop2("Updating parameters in `update_pars_brmsfit' failed.")

Check warning on line 131 in R/brmsfit_functions.R

View workflow job for this annotation

GitHub Actions / lint

file=R/brmsfit_functions.R,line=131,col=5,[object_usage_linter] no visible global function definition for 'stop2'
}
for (i in seq_len(npars)) {
new_draws[[i]] <- pars[i, ]
}
# create new sim object to overwrite x$fit@sim
x$fit@sim <- list(
samples = list(new_draws),
iter = ndraws,
thin = 1,
warmup = 0,
chains = 1,
n_save = ndraws,
warmup2 = 0,
permutation = list(seq_len(ndraws)),
pars_oi = x$fit@sim$pars_oi_old,
dims_oi = x$fit@sim$dims_oi_old,
fnames_oi = x$fit@sim$fnames_oi_old,
n_flatnames = length(x$fit@sim$fnames_oi_old)
)
x$fit@stan_args <- list(
list(chain_id = 1, iter = ndraws, thin = 1, warmup = 0)
)
brms::rename_pars(x)
}

# update .MISC environment of the stanfit object
# allows to call log_prob and other C++ using methods
# on objects not created in the current R session
# or objects created via another backend
# update_misc_env <- function(x, recompile = FALSE, only_windows = FALSE) {
# stopifnot(is.brmsfit(x))
# recompile <- as_one_logical(recompile)
# only_windows <- as_one_logical(only_windows)
# if (recompile || !has_rstan_model(x)) {
# x <- add_rstan_model(x, overwrite = TRUE)
# } else if (os_is_windows() || !only_windows) {
# # TODO: detect when updating .MISC is not required
# # TODO: find a more efficient way to update .MISC
# old_backend <- x$backend
# x$backend <- "rstan"
# [email protected] <- suppressMessages(brm(fit = x, chains = 0))[email protected]
# x$backend <- old_backend
# }
# x
# }
58 changes: 58 additions & 0 deletions man/moment_match.brmsfit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit f3d4105

Please sign in to comment.