Skip to content

Commit

Permalink
Merge pull request #725 from njtierney/add-checkers-test-posteriors-723
Browse files Browse the repository at this point in the history
Add checkers test posteriors 723
  • Loading branch information
njtierney authored Oct 16, 2024
2 parents 595f7b2 + 1bf0946 commit cb14e95
Show file tree
Hide file tree
Showing 10 changed files with 269 additions and 178 deletions.
53 changes: 43 additions & 10 deletions tests/testthat/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,7 @@ check_mvn_samples <- function(sampler, n_effective = 3000) {
# away from truth. There's a 1/100 chance of any one of these scaled errors
# being greater than qnorm(0.99) if the sampler is correct
errors <- scaled_error(stat_draws, stat_truth)
expect_lte(max(errors), stats::qnorm(0.99))
errors
}

# sample values of greta array 'x' (which must follow a distribution), and
Expand Down Expand Up @@ -864,19 +864,52 @@ check_samples <- function(
iid_samples <- iid_function(neff)
mcmc_samples <- as.matrix(draws)

# plot
if (is.null(title)) {
distrib <- get_node(x)$distribution$distribution_name
sampler_name <- class(sampler)[1]
title <- paste(distrib, "with", sampler_name)
}
# # plot
# if (is.null(title)) {
# distrib <- get_node(x)$distribution$distribution_name
# sampler_name <- class(sampler)[1]
# title <- paste(distrib, "with", sampler_name)
# }

# stats::qqplot(mcmc_samples, iid_samples, main = title)
# graphics::abline(0, 1)

# do a formal hypothesis test
# suppressWarnings(stat <- ks.test(mcmc_samples, iid_samples))
# testthat::expect_gte(stat$p.value, 0.01)

list(
mcmc_samples = mcmc_samples,
iid_samples = iid_samples,
distrib = get_node(x)$distribution$distribution_name,
sampler_name = class(sampler)[1]
)
}

qqplot_checked_samples <- function(checked_samples, title){

distrib <- checked_samples$distrib
sampler_name <- checked_samples$sampler_name
title <- paste(distrib, "with", sampler_name)

mcmc_samples <- checked_samples$mcmc_samples
iid_samples <- checked_samples$iid_samples

stats::qqplot(
x = mcmc_samples,
y = iid_samples,
main = title
)

stats::qqplot(mcmc_samples, iid_samples, main = title)
graphics::abline(0, 1)
}

## helpers for running Kolmogorov-Smirnov test for MCMC samples vs IID samples
ks_test_mcmc_vs_iid <- function(checked_samples){
# do a formal hypothesis test
suppressWarnings(stat <- ks.test(mcmc_samples, iid_samples))
testthat::expect_gte(stat$p.value, 0.01)
suppressWarnings(stat <- ks.test(checked_samples$mcmc_samples,
checked_samples$iid_samples))
stat
}

## helpers for looping through optimisers
Expand Down
163 changes: 0 additions & 163 deletions tests/testthat/test_posteriors.R

This file was deleted.

34 changes: 34 additions & 0 deletions tests/testthat/test_posteriors_binomial.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
Sys.setenv("RELEASE_CANDIDATE" = "true")
test_that("posterior is correct (binomial)", {
skip_if_not(check_tf_version())

skip_if_not_release()

# analytic solution to the posterior of the paramter of a binomial
# distribution, with uniform prior
n <- 100
pos <- rbinom(1, n, runif(1))
theta <- uniform(0, 1)
distribution(pos) <- binomial(n, theta)
m <- model(theta)

draws <- get_enough_draws(m, hmc(), 2000, verbose = FALSE)

samples <- as.matrix(draws)

# analytic solution to posterior is beta(1 + pos, 1 + N - pos)
shape1 <- 1 + pos
shape2 <- 1 + n - pos

# qq plot against true quantiles
quants <- (1:99) / 100
q_target <- qbeta(quants, shape1, shape2)
q_est <- quantile(samples, quants)
plot(q_target ~ q_est, main = "binomial posterior")
abline(0, 1)

n_draws <- round(coda::effectiveSize(draws))
comparison <- rbeta(n_draws, shape1, shape2)
suppressWarnings(test <- ks.test(samples, comparison))
expect_gte(test$p.value, 0.01)
})
17 changes: 17 additions & 0 deletions tests/testthat/test_posteriors_bivariate_normal.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Currently takes about 30 seconds on an M1 mac
Sys.setenv("RELEASE_CANDIDATE" = "false")

test_that("samplers are unbiased for bivariate normals", {
skip_if_not(check_tf_version())

skip_if_not_release()

hmc_mvn_samples <- check_mvn_samples(sampler = hmc())
expect_lte(max(hmc_mvn_samples), stats::qnorm(0.99))

rwmh_mvn_samples <- check_mvn_samples(sampler = rwmh())
expect_lte(max(rwmh_mvn_samples), stats::qnorm(0.99))

slice_mvn_samples <- check_mvn_samples(sampler = slice())
expect_lte(max(rwmh_mvn_samples), stats::qnorm(0.99))
})
23 changes: 23 additions & 0 deletions tests/testthat/test_posteriors_chi_squared.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
Sys.setenv("RELEASE_CANDIDATE" = "true")

test_that("samplers are unbiased for chi-squared", {
skip_if_not(check_tf_version())

skip_if_not_release()

df <- 5
x <- chi_squared(df)
iid <- function(n) rchisq(n, df)

chi_squared_checked <- check_samples(x = x,
iid_function = iid,
sampler = hmc())

# do the plotting
qqplot_checked_samples(chi_squared_checked)

# do a formal hypothesis test
stat <- ks_test_mcmc_vs_iid(chi_squared_checked)

expect_gte(stat$p.value, 0.01)
})
66 changes: 66 additions & 0 deletions tests/testthat/test_posteriors_geweke.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
Sys.setenv("RELEASE_CANDIDATE" = "false")

## TF1/2 - method for this test needs to be updated for TF2
## See https://github.com/greta-dev/greta/issues/720
test_that("samplers pass geweke tests", {
skip_if_not(check_tf_version())

skip_if_not_release()

# nolint start
# run geweke tests on this model:
# theta ~ normal(mu1, sd1)
# x[i] ~ normal(theta, sd2)
# for i in N
# nolint end

n <- 10
mu1 <- rnorm(1, 0, 3)
sd1 <- rlnorm(1)
sd2 <- rlnorm(1)

# prior (n draws)
p_theta <- function(n) {
rnorm(n, mu1, sd1)
}

# likelihood
p_x_bar_theta <- function(theta) {
rnorm(n, theta, sd2)
}

# define the greta model (single precision for slice sampler)
x <- as_data(rep(0, n))
greta_theta <- normal(mu1, sd1)
distribution(x) <- normal(greta_theta, sd2)
model <- model(greta_theta, precision = "single")

# run tests on all available samplers
check_geweke(
sampler = hmc(),
model = model,
data = x,
p_theta = p_theta,
p_x_bar_theta = p_x_bar_theta,
title = "HMC Geweke test"
)

check_geweke(
sampler = rwmh(),
model = model,
data = x,
p_theta = p_theta,
p_x_bar_theta = p_x_bar_theta,
warmup = 2000,
title = "RWMH Geweke test"
)

check_geweke(
sampler = slice(),
model = model,
data = x,
p_theta = p_theta,
p_x_bar_theta = p_x_bar_theta,
title = "slice sampler Geweke test"
)
})
Loading

0 comments on commit cb14e95

Please sign in to comment.