Use pareto smoothing functions from posterior package (#3)
* use posterior for pareto smoothing

* remove unnecessary TODOs

* small additions

* remove redundant line

* linting

* remove loo as dependency
topipa committed Sep 6, 2023
1 parent 071dc66 commit ffda44f
Showing 4 changed files with 28 additions and 20 deletions.
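
At its core, the commit swaps the `loo::psis()` based smoothing for `posterior::pareto_smooth()`. A minimal sketch of that substitution, restating the removed and added lines from the diffs below (variable names as in `R/moment_match.R`):

```r
# before: smooth log importance weights with loo's PSIS
# lw_psis <- suppressWarnings(loo::psis(lw))
# lw      <- as.vector(weights(lw_psis))
# k       <- lw_psis$diagnostics$pareto_k

# after: smooth the normalized weights with posterior, then return to the log scale
pareto_smoothed_w <- posterior::pareto_smooth(
  exp(lw - matrixStats::logSumExp(lw)),  # normalized importance weights
  tail = "right", extra_diags = TRUE, r_eff = 1
)
k <- pareto_smoothed_w$diagnostics$khat    # Pareto k-hat diagnostic
lw <- log(as.vector(pareto_smoothed_w$x))  # smoothed log weights
```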
1 change: 0 additions & 1 deletion DESCRIPTION
@@ -15,7 +15,6 @@ Depends:
Imports:
    abind,
    checkmate,
-   loo (>= 2.3.1),
    matrixStats (>= 0.52),
    posterior,
    stats
27 changes: 17 additions & 10 deletions R/moment_match.R
@@ -117,9 +117,12 @@ moment_match.matrix <- function(x,
target density is equal to your proposal density.")
}

- lw_psis <- suppressWarnings(loo::psis(lw))
- lw <- as.vector(weights(lw_psis))
- k <- lw_psis$diagnostics$pareto_k
+ pareto_smoothed_w <- posterior::pareto_smooth(exp(lw - matrixStats::logSumExp(lw)),
+   tail = "right", extra_diags = TRUE, r_eff = 1
+ )
+ k <- pareto_smoothed_w$diagnostics$khat
+ lw <- log(as.vector(pareto_smoothed_w$x))

if (any(is.infinite(k))) {
stop("Something went wrong, and encountered infinite Pareto k values..")
}
@@ -153,8 +156,14 @@ moment_match.matrix <- function(x,
adapted_draws <- list("draws" = draws, "log_weights" = lw, "pareto_k" = k)
} else {
lwf <- compute_lwf(draws, lw, expectation_fun, log_expectation_fun, ...)
- psisf <- suppressWarnings(loo::psis(lwf))
- kf <- psisf$diagnostics$pareto_k
+
+ pareto_smoothed_wf <- apply(lwf, 2, function(x) {
+   posterior::pareto_smooth(exp(x),
+     tail = "right", extra_diags = TRUE, r_eff = 1
+   )
+ })
+ pareto_smoothed_wf <- do.call(mapply, c(cbind, pareto_smoothed_wf))
+ kf <- as.numeric(pareto_smoothed_wf$diagnostics["khat", ])

if (split) {
# prepare for split and check kfs
@@ -173,13 +182,13 @@
total_mapping2 <- total_mapping
}

- lwf <- compute_lwf(draws2, lw, expectation_fun, log_expectation_fun, ...)
- if (ncol(lwf) > 1) {
+ lwf_check <- compute_lwf(draws2, lw, expectation_fun, log_expectation_fun, ...)
+ if (ncol(lwf_check) > 1) {
stop("Using split = TRUE is not yet supported for expectation functions
that return a matrix. As a workaround, you can wrap your function
call using apply.")
}
- lwf <- as.vector(weights(psisf))
+ lwf <- log(as.vector(pareto_smoothed_wf$x))

if (is.null(log_prob_target_fun) && is.null(log_ratio_fun)) {
update_properties <- list(
@@ -335,8 +344,6 @@ moment_match.matrix <- function(x,
# )


- # lw_trans_psis <- suppressWarnings(loo::psis(lw_trans))
- # lw_trans <- as.vector(weights(lw_trans_psis))
lw_trans <- lw_trans - matrixStats::logSumExp(lw_trans)


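The least obvious new line in `moment_match.matrix` is `do.call(mapply, c(cbind, pareto_smoothed_wf))`, which turns the per-column list produced by `apply()` into a single list whose matching components are bound column-wise. A small, self-contained illustration with made-up numbers (the real elements come from `posterior::pareto_smooth()`):

```r
# two hypothetical per-column results, shaped like pareto_smooth() output
res <- list(
  list(x = c(0.2, 0.3, 0.5), diagnostics = list(khat = 0.4, min_ss = 10)),
  list(x = c(0.1, 0.6, 0.3), diagnostics = list(khat = 0.7, min_ss = 12))
)

# do.call(mapply, c(cbind, res)) is mapply(cbind, res[[1]], res[[2]]):
# it binds the "x" components into one matrix and the "diagnostics"
# components into another, element by element
combined <- do.call(mapply, c(cbind, res))

combined$x                                  # 3 x 2 matrix of smoothed weights
as.numeric(combined$diagnostics["khat", ])  # c(0.4, 0.7), one k-hat per column
```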
10 changes: 7 additions & 3 deletions R/update_quantities.R
@@ -41,9 +41,13 @@ update_quantities <- function(draws, orig_log_prob_prop,
)
}

- psis <- suppressWarnings(loo::psis(lw_new))
- k <- psis$diagnostics$pareto_k
- lw <- as.vector(weights(psis))
+ pareto_smoothed_w_new <- posterior::pareto_smooth(exp(lw_new - matrixStats::logSumExp(lw_new)),
+   tail = "right", r_eff = 1
+ )
+ k <- pareto_smoothed_w_new$diagnostics$khat
+ lw <- log(as.vector(pareto_smoothed_w_new$x))
+ # normalize log weights
+ lw <- lw - matrixStats::logSumExp(lw)

# gather results
list(
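The extra step added at the end of `update_quantities()` just rescales the smoothed log weights so the weights sum to one. A tiny sanity check of what that line does, with assumed values:

```r
# normalizing log weights: after subtracting logSumExp, exp(lw) sums to 1
lw <- c(-1.2, -0.3, -2.5)
lw <- lw - matrixStats::logSumExp(lw)
sum(exp(lw))
#> [1] 1
```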
10 changes: 4 additions & 6 deletions vignettes/importance_sampling_bootstrap.Rmd
@@ -56,12 +56,10 @@ algorithm (Paananen et al. 2020).
We will load __rstan__ for fitting our models, and the
[_iwmm package_](https://github.com/topipa/iwmm) for
the importance weighted moment matching.
- We also load the __loo__ package for some helper functions.


```{r load, message=FALSE}
library("rstan")
library("loo")
library("iwmm")
seed <- 48571056
set.seed(seed)
@@ -114,7 +112,7 @@ Stan using the rstan package:
stanmodel <- stan_model(model_code = stancode)
# generate data
- n <- as.integer(1000)
+ n <- as.integer(100)
x <- rnorm(n = n)
standata <- list(N = n, x = x)
@@ -123,7 +121,7 @@ stanfit <- sampling(stanmodel, data = standata, refresh = 0)
# extract posterior and log-likelihood draws
post <- as.data.frame(stanfit)[, 1:2]
S <- nrow(post)
- ll <- loo::extract_log_lik(stanfit)
+ ll <- as.matrix(stanfit, pars = "log_lik")
```

### Compute Bootstrap means using importance sampling
@@ -185,7 +183,7 @@ times each observation is included in the Bootstrap sample.
```{r helper functions, warning = FALSE}
log_lik_stanfit <- function(fit, upars, parameter_name = "log_lik",
...) {
-   ll <- loo::extract_log_lik(fit, parameter_name, merge_chains = TRUE)
+   ll <- as.matrix(fit, pars = parameter_name)
S <- nrow(upars)
n <- ncol(ll)
out <- matrix(0, S, n)
@@ -303,7 +301,7 @@ stanfit <- sampling(stanmodel, data = standata, refresh = 0)
# extract posterior and log-likelihood draws
post <- as.data.frame(stanfit)[, 1:2]
S <- nrow(post)
- ll <- loo::extract_log_lik(stanfit)
+ ll <- as.matrix(stanfit, pars = "log_lik")
```


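In the vignette, the only substantive change besides the smaller `n` is that the pointwise log-likelihood matrix is now extracted with plain rstan instead of loo's helper. Both calls should yield a draws-by-observations matrix; a sketch of the swap (assuming the model exposes a `log_lik` generated quantity, as in the vignette's Stan code):

```r
# removed: requires the loo package
# ll <- loo::extract_log_lik(stanfit)

# added: same matrix straight from rstan, dropping the loo dependency
ll <- as.matrix(stanfit, pars = "log_lik")
dim(ll)  # draws x observations
```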
