Skip to content

Commit

Permalink
make laplace agree with GRaF
Browse files Browse the repository at this point in the history
  • Loading branch information
goldingn committed Feb 26, 2020
1 parent 60ba1b2 commit 5e3dd96
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 18 deletions.
46 changes: 30 additions & 16 deletions R/marginalisers.R
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ laplace_approximation <- function(tolerance = 1e-6,
derivs <- function(z) {
y <- d0(z, reduce = FALSE)
d1 <- tf$gradients(y, z)[[1]]
d2 <- tf$hessians(y, z)[[1]]
d2 <- tf$hessians(y, z)[[1]] # this won't work!
list(d1, d2)
}
}
Expand All @@ -237,13 +237,9 @@ laplace_approximation <- function(tolerance = 1e-6,
# dimension of the MVN distribution
n <- dim(mu)[[2]]

# randomly initialise z, and expand to batch (warm starts will need some TF
# trickery)

# here z is a *column vector* to simplify later calculations, it needs to be
# transposed to a row vector before feeding into the likelihood function(s)
z_value <- add_first_dim(as_2d_array(rnorm(n)))
z <- tf$constant(z_value, dtype = tf_float())
z <- tf$identity(mu)

# Newton-Raphson parameters
tol <- tf$constant(tolerance, tf_float(), shape(1))
Expand All @@ -253,11 +249,10 @@ laplace_approximation <- function(tolerance = 1e-6,

# other objects
a_value <- add_first_dim(as_2d_array(rep(0, n)))
a <- tf$constant(a_value, dtype = tf_float())
a <- fl(a_value)
u_value <- add_first_dim(diag(n))
u <- tf$constant(u_value, tf_float())
eye <- tf$constant(add_first_dim(diag(n)),
dtype = tf_float())
u <- fl(u_value)
eye <- fl(add_first_dim(diag(n)))

# match batches on everything going into the loop that will have a batch
# dimension later
Expand Down Expand Up @@ -310,11 +305,11 @@ laplace_approximation <- function(tolerance = 1e-6,
s <- tf$expand_dims(s, 1L)
s <- tf$expand_dims(s, 2L)
a_new <- a + s * adiff
z_new <- tf$matmul(sigma, a) + mu
z_new <- tf$matmul(sigma, a_new) + mu
psi(a_new, z_new, mu)
}

ls_results <- gss(psiline, batch_dim)
ls_results <- gss(psiline, batch_dim, upper = 2)
stepsize <- ls_results$minimum
stepsize <- tf$expand_dims(stepsize, 1L)
stepsize <- tf$expand_dims(stepsize, 2L)
Expand Down Expand Up @@ -363,11 +358,17 @@ laplace_approximation <- function(tolerance = 1e-6,
mat1 <- tf$matmul(rw, tf_transpose(rw)) * sigma + eye
u <- tf$cholesky(mat1)

# convergence information
iter <- out[[7]]
converged <- tf$less(iter, maxiter)

# return a list of these things
list(z = z,
mu = mu,
a = a,
u = u)
u = u,
iterations = iter,
converged = converged)

}

Expand Down Expand Up @@ -423,12 +424,23 @@ laplace_approximation <- function(tolerance = 1e-6,
tf_operation = "get_element",
operation_args = list("u"))

iterations <- op("iterations",
parameter_list,
tf_operation = "get_element",
operation_args = list("iterations"))

converged <- op("converged",
parameter_list,
tf_operation = "get_element",
operation_args = list("converged"))

# pull out the elements
list(z = z,
a = a,
mu = mu,
u = u)
u = u,
iterations = iterations,
converged = converged)

}

Expand Down Expand Up @@ -470,8 +482,10 @@ laplace_approximation <- function(tolerance = 1e-6,

return_list_function <- function(parameters) {

list(mean = parameters$z - parameters$mu,
sigma = chol2symm(parameters$u))
list(mean = t(parameters$z),
sigma = chol2symm(parameters$u),
iterations = parameters$iterations,
converged = parameters$converged)

}

Expand Down
7 changes: 5 additions & 2 deletions tests/testthat/test_marginalisation.R
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,16 @@ test_that("laplace approximation converges on correct posterior", {
out <- marginalise(lik,
multivariate_normal(mean, sigma),
laplace_approximation(diagonal_hessian = TRUE))
theta_mu_est <- t(calculate(out$mean)[[1]])
theta_var_est <- calculate(diag(out$sigma))[[1]]
res <- calculate(mean = t(out$mean), diag_sigma = diag(out$sigma))
theta_mu_est <- res$mean
theta_var_est <- res$diag_sigma

analytic <- cbind(mean = theta_mu, sd = sqrt(theta_var))
laplace <- cbind(mean = theta_mu_est, sd = sqrt(theta_var_est))

# compare these to within a tolerance
compare_op(analytic, laplace)

# modes are right, sds are not!

})

0 comments on commit 5e3dd96

Please sign in to comment.