Skip to content

Commit

Permalink
Merge pull request #314 from n-kall/doc-updates
Browse files Browse the repository at this point in the history
Pareto-smoothing updates
  • Loading branch information
paul-buerkner authored Nov 23, 2023
2 parents e420f13 + b062ce4 commit adf3813
Show file tree
Hide file tree
Showing 26 changed files with 285 additions and 65 deletions.
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

### Enhancements

* Add `pareto_smooth` option to `weight_draws`, to Pareto smooth
weights before adding to a draws object.
* Matrix multiplication of `rvar`s can now be done with the base matrix
multiplication operator (`%*%`) instead of `%**%` in R >= 4.3.


# posterior 1.5.0

### Enhancements
Expand Down
2 changes: 2 additions & 0 deletions R/convergence.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#' | [mcse_mean()] | Monte Carlo standard error for the mean |
#' | [mcse_quantile()] | Monte Carlo standard error for quantiles |
#' | [mcse_sd()] | Monte Carlo standard error for standard deviations |
#' | [pareto_khat()] | Pareto khat diagnostic for tail(s) |
#' | [pareto_diags()] | Additional diagnostics related to Pareto khat |
#' | [rhat_basic()] | Basic version of Rhat |
#' | [rhat()] | Improved, rank-based version of Rhat |
#' | [rhat_nested()] | Rhat for use with many short chains |
Expand Down
122 changes: 87 additions & 35 deletions R/pareto_smooth.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
#' the number of fractional moments that is useful for convergence
#' diagnostics. For further details see Vehtari et al. (2022).
#'
#' @family diagnostics
#' @template args-pareto
#' @template args-methods-dots
#' @template ref-vehtari-paretosmooth-2022
#' @return `khat`, the estimated shape parameter k of the generalized Pareto distribution
#'
#' @seealso [`pareto_diags`] for additional related diagnostics, and
#' [`pareto_smooth`] for Pareto smoothed draws.
#' @examples
#' mu <- extract_variable_matrix(example_draws(), "mu")
#' pareto_khat(mu)
Expand All @@ -25,6 +29,7 @@ pareto_khat.default <- function(x,
r_eff = NULL,
ndraws_tail = NULL,
verbose = FALSE,
are_log_weights = FALSE,
...) {
smoothed <- pareto_smooth.default(
x,
Expand All @@ -34,6 +39,7 @@ pareto_khat.default <- function(x,
verbose = verbose,
return_k = TRUE,
smooth_draws = FALSE,
are_log_weights = are_log_weights,
...)
return(smoothed$diagnostics)
}
Expand Down Expand Up @@ -65,6 +71,7 @@ pareto_khat.rvar <- function(x, ...) {
#' replacing tail draws by order statistics of a generalized Pareto
#' distribution fit to the tail(s).
#'
#' @family diagnostics
#' @template args-pareto
#' @template args-methods-dots
#' @template ref-vehtari-paretosmooth-2022
Expand Down Expand Up @@ -100,6 +107,8 @@ pareto_khat.rvar <- function(x, ...) {
#' when the sample size is increased, compared to the central limit
#' theorem convergence rate. See Appendix B in Vehtari et al. (2022).
#'
#' @seealso [`pareto_khat`] for only calculating khat, and
#' [`pareto_smooth`] for Pareto smoothed draws.
#' @examples
#' mu <- extract_variable_matrix(example_draws(), "mu")
#' pareto_diags(mu)
Expand All @@ -113,11 +122,12 @@ pareto_diags <- function(x, ...) UseMethod("pareto_diags")
#' @rdname pareto_diags
#' @export
pareto_diags.default <- function(x,
tail = c("both", "right", "left"),
r_eff = NULL,
ndraws_tail = NULL,
verbose = FALSE,
...) {
tail = c("both", "right", "left"),
r_eff = NULL,
ndraws_tail = NULL,
verbose = FALSE,
are_log_weights = FALSE,
...) {

smoothed <- pareto_smooth.default(
x,
Expand All @@ -128,6 +138,7 @@ pareto_diags.default <- function(x,
extra_diags = TRUE,
verbose = verbose,
smooth_draws = FALSE,
are_log_weights = FALSE,
...)

return(smoothed$diagnostics)
Expand Down Expand Up @@ -189,6 +200,8 @@ pareto_diags.rvar <- function(x, ...) {
#' Pareto smoothed estimates
#' * `convergence_rate`: Relative convergence rate for Pareto smoothed estimates
#'
#' @seealso [`pareto_khat`] for only calculating khat, and
#' [`pareto_diags`] for additional diagnostics.
#' @examples
#' mu <- extract_variable_matrix(example_draws(), "mu")
#' pareto_smooth(mu)
Expand Down Expand Up @@ -225,8 +238,8 @@ pareto_smooth.rvar <- function(x, return_k = TRUE, extra_diags = FALSE, ...) {
)
}
out <- list(
x = rvar(apply(draws_diags, margins, function(x) x[[1]]$x), nchains = nchains(x)),
diagnostics = diags
x = rvar(apply(draws_diags, margins, function(x) x[[1]]$x), nchains = nchains(x)),
diagnostics = diags
)
} else {
out <- rvar(apply(draws_diags, margins, function(x) x[[1]]), nchains = nchains(x))
Expand All @@ -238,25 +251,36 @@ pareto_smooth.rvar <- function(x, return_k = TRUE, extra_diags = FALSE, ...) {
#' @export
pareto_smooth.default <- function(x,
tail = c("both", "right", "left"),
r_eff = NULL,
r_eff = 1,
ndraws_tail = NULL,
return_k = TRUE,
extra_diags = FALSE,
verbose = FALSE,
are_log_weights = FALSE,
...) {

checkmate::assert_number(ndraws_tail, null.ok = TRUE)
checkmate::assert_number(r_eff, null.ok = TRUE)
checkmate::assert_logical(extra_diags)
checkmate::assert_logical(return_k)
checkmate::assert_logical(verbose)
checkmate::expect_numeric(ndraws_tail, null.ok = TRUE)
checkmate::expect_numeric(r_eff, null.ok = TRUE)
extra_diags <- as_one_logical(extra_diags)
return_k <- as_one_logical(return_k)
verbose <- as_one_logical(verbose)
are_log_weights <- as_one_logical(are_log_weights)

# check for infinite or na values
if (should_return_NA(x)) {
warning_no_call("Input contains infinite or NA values, Pareto smoothing not performed.")
return(list(x = x, diagnostics = NA_real_))
warning_no_call("Input contains infinite or NA values, or is constant. Fitting of generalized Pareto distribution not performed.")
if (!return_k) {
out <- x
} else {
out <- list(x = x, diagnostics = NA_real_)
}
return(out)
}

if (are_log_weights) {
tail <- "right"
}

tail <- match.arg(tail)
S <- length(x)

Expand Down Expand Up @@ -290,6 +314,7 @@ pareto_smooth.default <- function(x,
x,
ndraws_tail = ndraws_tail,
tail = "left",
are_log_weights = are_log_weights,
...
)
left_k <- smoothed$k
Expand All @@ -299,12 +324,14 @@ pareto_smooth.default <- function(x,
x = smoothed$x,
ndraws_tail = ndraws_tail,
tail = "right",
are_log_weights = are_log_weights,
...
)
right_k <- smoothed$k

k <- max(left_k, right_k)
x <- smoothed$x

} else {

smoothed <- .pareto_smooth_tail(
Expand All @@ -326,10 +353,11 @@ pareto_smooth.default <- function(x,

if (verbose) {
if (!extra_diags) {
diags_list <- .pareto_smooth_extra_diags(diags_list$khat, length(x))
diags_list <- c(diags_list, .pareto_smooth_extra_diags(diags_list$khat, length(x)))
}
pareto_k_diagmsg(
diags = diags_list
diags = diags_list,
are_weights = are_log_weights
)
}

Expand All @@ -349,26 +377,32 @@ pareto_smooth.default <- function(x,
ndraws_tail,
smooth_draws = TRUE,
tail = c("right", "left"),
are_log_weights = FALSE,
...
) {

if (are_log_weights) {
# shift log values for safe exponentiation
x <- x - max(x)
}

tail <- match.arg(tail)

S <- length(x)
tail_ids <- seq(S - ndraws_tail + 1, S)


if (tail == "left") {
x <- -x
}

ord <- sort.int(x, index.return = TRUE)
draws_tail <- ord$x[tail_ids]
cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values

cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values

max_tail <- max(draws_tail)
min_tail <- min(draws_tail)

if (ndraws_tail >= 5) {
ord <- sort.int(x, index.return = TRUE)
if (abs(max_tail - min_tail) < .Machine$double.eps / 100) {
Expand All @@ -380,12 +414,19 @@ pareto_smooth.default <- function(x,
k <- NA
} else {
# save time not sorting since x already sorted
fit <- gpdfit(draws_tail - cutoff, sort_x = FALSE)
if (are_log_weights) {
draws_tail <- exp(draws_tail)
cutoff <- exp(cutoff)
}
fit <- gpdfit(draws_tail - cutoff, sort_x = FALSE, ...)
k <- fit$k
sigma <- fit$sigma
if (is.finite(k) && smooth_draws) {
p <- (seq_len(ndraws_tail) - 0.5) / ndraws_tail
smoothed <- qgeneralized_pareto(p = p, mu = cutoff, k = k, sigma = sigma)
if (are_log_weights) {
smoothed <- log(smoothed)
}
} else {
smoothed <- NULL
}
Expand Down Expand Up @@ -445,11 +486,11 @@ pareto_smooth.default <- function(x,
#' @noRd
ps_min_ss <- function(k, ...) {
  # Maps a Pareto shape value `k` to 10^(1 / (1 - max(0, k))) when k < 1 and
  # to Inf when k >= 1. Presumably this is the minimum sample size reported
  # as `min_ss` in the diagnostics -- TODO confirm against the caller.
  # `k` is expected to be a single number; `...` is accepted but unused.
  if (is.na(k)) {
    # `if (NA < 1)` would raise "missing value where TRUE/FALSE needed";
    # propagate the missing value instead of failing.
    return(NA_real_)
  }
  if (k < 1) {
    # Negative k is clamped to 0, so the result is never below 10.
    10^(1 / (1 - max(0, k)))
  } else {
    Inf
  }
}


Expand Down Expand Up @@ -506,27 +547,38 @@ ps_tail_length <- function(S, r_eff, ...) {
#'
#' Given a set of Pareto diagnostic values, form a diagnostic message string
#' @param diags (list) named list of diagnostic values with elements `khat`,
#' `min_ss`, `khat_threshold`, and `convergence_rate`
#' @param are_weights (logical) are the diagnostics for weights?
#' @param ... unused
#' @return diagnostic message
#' @noRd
pareto_k_diagmsg <- function(diags, ...) {
pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) {
khat <- diags$khat
min_ss <- diags$min_ss
khat_threshold <- diags$khat_threshold
convergence_rate <- diags$convergence_rate
msg <- NULL
if (khat > 1) {
msg <- paste0(msg,'All estimates are unreliable. If the distribution of ratios is bounded,\n',
'further draws may improve the estimates, but it is not possible to predict\n',
'whether any feasible sample size is sufficient.')
} else {
if (khat > khat_threshold) {
msg <- paste0(msg, 'S is too small, and sample size larger than ', round(min_ss, 0), ' is needed for reliable results.\n')

if (!are_weights) {

if (khat > 1) {
msg <- paste0(msg, "All estimates are unreliable. If the distribution of draws is bounded,\n",
"further draws may improve the estimates, but it is not possible to predict\n",
"whether any feasible sample size is sufficient.")
} else {
msg <- paste0(msg, 'To halve the RMSE, approximately ', round(2^(2/convergence_rate),1), ' times bigger S is needed.\n')
if (khat > khat_threshold) {
msg <- paste0(msg, "S is too small, and sample size larger than ", round(min_ss, 0), " is needed for reliable results.\n")
} else {
msg <- paste0(msg, "To halve the RMSE, approximately ", round(2^(2 / convergence_rate), 1), " times bigger S is needed.\n")
}
if (khat > 0.7) {
msg <- paste0(msg, "Bias dominates RMSE, and the variance based MCSE is underestimated.\n")
}
}
if (khat > 0.7) {
msg <- paste0(msg, 'Bias dominates RMSE, and the variance based MCSE is underestimated.\n')

} else {

if (khat > khat_threshold || khat > 0.7) {
msg <- paste0(msg, "Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n")
}
}
message(msg)
Expand Down
Loading

0 comments on commit adf3813

Please sign in to comment.