From 5e3dd969e6e2174d275cebd838e1b9f67163476a Mon Sep 17 00:00:00 2001 From: Nick Golding Date: Wed, 26 Feb 2020 16:23:36 +1100 Subject: [PATCH] make laplace agree with GRaF --- R/marginalisers.R | 46 +++++++++++++++++---------- tests/testthat/test_marginalisation.R | 7 ++-- 2 files changed, 35 insertions(+), 18 deletions(-) diff --git a/R/marginalisers.R b/R/marginalisers.R index ba0f562f..40d4a161 100644 --- a/R/marginalisers.R +++ b/R/marginalisers.R @@ -220,7 +220,7 @@ laplace_approximation <- function(tolerance = 1e-6, derivs <- function(z) { y <- d0(z, reduce = FALSE) d1 <- tf$gradients(y, z)[[1]] - d2 <- tf$hessians(y, z)[[1]] + d2 <- tf$hessians(y, z)[[1]] # this won't work! list(d1, d2) } } @@ -237,13 +237,9 @@ laplace_approximation <- function(tolerance = 1e-6, # dimension of the MVN distribution n <- dim(mu)[[2]] - # randomly initialise z, and expand to batch (warm starts will need some TF - # trickery) - # here z is a *column vector* to simplify later calculations, it needs to be # transposed to a row vector before feeding into the likelihood function(s) - z_value <- add_first_dim(as_2d_array(rnorm(n))) - z <- tf$constant(z_value, dtype = tf_float()) + z <- tf$identity(mu) # Newton-Raphson parameters tol <- tf$constant(tolerance, tf_float(), shape(1)) @@ -253,11 +249,10 @@ laplace_approximation <- function(tolerance = 1e-6, # other objects a_value <- add_first_dim(as_2d_array(rep(0, n))) - a <- tf$constant(a_value, dtype = tf_float()) + a <- fl(a_value) u_value <- add_first_dim(diag(n)) - u <- tf$constant(u_value, tf_float()) - eye <- tf$constant(add_first_dim(diag(n)), - dtype = tf_float()) + u <- fl(u_value) + eye <- fl(add_first_dim(diag(n))) # match batches on everything going into the loop that will have a batch # dimension later @@ -310,11 +305,11 @@ laplace_approximation <- function(tolerance = 1e-6, s <- tf$expand_dims(s, 1L) s <- tf$expand_dims(s, 2L) a_new <- a + s * adiff - z_new <- tf$matmul(sigma, a) + mu + z_new <- tf$matmul(sigma, a_new) + mu psi(a_new, z_new, mu) } - ls_results <- gss(psiline, batch_dim) + ls_results <- gss(psiline, batch_dim, upper = 2) stepsize <- ls_results$minimum stepsize <- tf$expand_dims(stepsize, 1L) stepsize <- tf$expand_dims(stepsize, 2L) @@ -363,11 +358,17 @@ laplace_approximation <- function(tolerance = 1e-6, mat1 <- tf$matmul(rw, tf_transpose(rw)) * sigma + eye u <- tf$cholesky(mat1) + # convergence information + iter <- out[[7]] + converged <- tf$less(iter, maxiter) + # return a list of these things list(z = z, mu = mu, a = a, - u = u) + u = u, + iterations = iter, + converged = converged) } @@ -423,12 +424,23 @@ laplace_approximation <- function(tolerance = 1e-6, tf_operation = "get_element", operation_args = list("u")) + iterations <- op("iterations", + parameter_list, + tf_operation = "get_element", + operation_args = list("iterations")) + + converged <- op("converged", + parameter_list, + tf_operation = "get_element", + operation_args = list("converged")) # pull out the elements list(z = z, a = a, mu = mu, - u = u) + u = u, + iterations = iterations, + converged = converged) } @@ -470,8 +482,10 @@ laplace_approximation <- function(tolerance = 1e-6, return_list_function <- function(parameters) { - list(mean = parameters$z - parameters$mu, - sigma = chol2symm(parameters$u)) + list(mean = t(parameters$z), + sigma = chol2symm(parameters$u), + iterations = parameters$iterations, + converged = parameters$converged) } diff --git a/tests/testthat/test_marginalisation.R b/tests/testthat/test_marginalisation.R index fc0520d6..04f43382 100644 --- a/tests/testthat/test_marginalisation.R +++ b/tests/testthat/test_marginalisation.R @@ -254,8 +254,9 @@ test_that("laplace approximation converges on correct posterior", { out <- marginalise(lik, multivariate_normal(mean, sigma), laplace_approximation(diagonal_hessian = TRUE)) - theta_mu_est <- t(calculate(out$mean)[[1]]) - theta_var_est <- calculate(diag(out$sigma))[[1]] + res <- calculate(mean = t(out$mean), diag_sigma = diag(out$sigma)) + theta_mu_est <- res$mean + theta_var_est <- res$diag_sigma analytic <- cbind(mean = theta_mu, sd = sqrt(theta_var)) laplace <- cbind(mean = theta_mu_est, sd = sqrt(theta_var_est)) @@ -263,4 +264,6 @@ test_that("laplace approximation converges on correct posterior", { # compare these to within a tolerance compare_op(analytic, laplace) + # modes are right, sds are not! + })