Skip to content

Commit

Permalink
get laplace variances working
Browse files Browse the repository at this point in the history
  • Loading branch information
goldingn committed Feb 26, 2020
1 parent 5e3dd96 commit 2bae9ab
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 20 deletions.
43 changes: 29 additions & 14 deletions R/marginalisers.R
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,19 @@ laplace_approximation <- function(tolerance = 1e-6,
w <- -d2
rw <- sqrt(w)

# approximate posterior covariance & cholesky factor
# approximate posterior covariance
# do we need the eye?
mat1 <- tf$matmul(rw, tf_transpose(rw)) * sigma + eye
u <- tf$cholesky(mat1)
l <- tf$cholesky(mat1)
v <- tf$linalg$triangular_solve(matrix = l,
rhs = sigma * rw,
lower = TRUE,
adjoint = TRUE)
covar <- sigma - tf$linalg$matmul(v, v, transpose_b = TRUE)

# log-determinant of l
l_diag <- tf$matrix_diag_part(l)
logdet <- tf_sum(tf$log(l_diag))

# convergence information
iter <- out[[7]]
Expand All @@ -366,7 +376,8 @@ laplace_approximation <- function(tolerance = 1e-6,
list(z = z,
mu = mu,
a = a,
u = u,
logdet = logdet,
covar = covar,
iterations = iter,
converged = converged)

Expand Down Expand Up @@ -418,11 +429,16 @@ laplace_approximation <- function(tolerance = 1e-6,
tf_operation = "get_element",
operation_args = list("mu"))

u <- op("chol_sigma",
parameter_list,
dim = dim(sigma),
tf_operation = "get_element",
operation_args = list("u"))
logdet <- op("log determinant",
parameter_list,
tf_operation = "get_element",
operation_args = list("logdet"))

covar <- op("covar",
parameter_list,
dim = dim(sigma),
tf_operation = "get_element",
operation_args = list("covar"))

iterations <- op("iterations",
parameter_list,
Expand All @@ -438,7 +454,8 @@ laplace_approximation <- function(tolerance = 1e-6,
list(z = z,
a = a,
mu = mu,
u = u,
logdet = logdet,
covar = covar,
iterations = iterations,
converged = converged)

Expand Down Expand Up @@ -467,14 +484,12 @@ laplace_approximation <- function(tolerance = 1e-6,
}

mu <- parameters$mu
u <- parameters$u
logdet <- parameters$logdet
z <- parameters$z
a <- parameters$a

# the approximate marginal conditional posterior
u_diag <- tf$matrix_diag_part(u)
logdet <- tf_sum(tf$log(u_diag))
nmcp <- psi(a, z, mu) + tf$squeeze(logdet, 1)
nmcp <- psi(a, z, mu) + tf$squeeze(u_logdet, 1)

-nmcp

Expand All @@ -483,7 +498,7 @@ laplace_approximation <- function(tolerance = 1e-6,
return_list_function <- function(parameters) {

list(mean = t(parameters$z),
sigma = chol2symm(parameters$u),
sigma = parameters$covar,
iterations = parameters$iterations,
converged = parameters$converged)

Expand Down
9 changes: 3 additions & 6 deletions tests/testthat/test_marginalisation.R
Original file line number Diff line number Diff line change
Expand Up @@ -243,14 +243,12 @@ test_that("laplace approximation converges on correct posterior", {
theta_mu <- (y * obs_prec + mu * prec) * theta_var
theta_sd <- sqrt(theta_var)

# Laplace solution:
lik <- function(theta) {
distribution(y) <- normal(t(theta), obs_sd)
}

# mock up as a multivariate normal distribution
mean <- ones(1, 8) * mu
sigma <- diag(8) * sd ^ 2
lik <- function(theta) {
distribution(y) <- normal(t(theta), obs_sd)
}
out <- marginalise(lik,
multivariate_normal(mean, sigma),
laplace_approximation(diagonal_hessian = TRUE))
Expand All @@ -264,6 +262,5 @@ test_that("laplace approximation converges on correct posterior", {
# compare these to within a tolerance
compare_op(analytic, laplace)

# modes are right, sds are not!

})

0 comments on commit 2bae9ab

Please sign in to comment.