
Changes from Cole Trapnell #110

Merged: 14 commits, Nov 28, 2023
Changes from 1 commit
61 changes: 59 additions & 2 deletions R/PLNfit-class.R
@@ -60,6 +60,8 @@ PLNfit <- R6Class(
## PRIVATE TORCH METHODS FOR OPTIMIZATION
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
torch_elbo = function(data, params, index=torch_tensor(1:self$n)) {
#print (index)
#print (params$S)
S2 <- torch_square(params$S[index])
Z <- data$O[index] + params$M[index] + torch_mm(data$X[index], params$B)
res <- .5 * sum(data$w[index]) * torch_logdet(private$torch_Sigma(data, params, index)) +
@@ -140,11 +142,12 @@ PLNfit <- R6Class(
objective[iterate + 1] <- loss$item()
B_new <- optimizer$param_groups[[1]]$params$B
delta_f <- abs(objective[iterate] - objective[iterate + 1]) / abs(objective[iterate + 1])
#delta_x = 0
delta_x <- torch::torch_sum(torch::torch_abs(B_old - B_new))/torch::torch_sum(torch::torch_abs(B_new))

#print (delta_f)
#print (delta_x)
delta_x = delta_x$cpu()
#print (delta_x)
delta_x = as.matrix(delta_x)
#print (delta_x)
@@ -156,7 +159,7 @@

## Check for convergence
if (delta_f < config$ftol_rel) status <- 3
if (delta_x < config$xtol_rel) status <- 4
#if (delta_x < config$xtol_rel) status <- 4
Collaborator
Is there a reason to remove the convergence check on the parameter values and to keep only the one on the ELBO value? This will speed up the algorithm (fewer conditions to satisfy) but may cause the nlopt and torch implementations to diverge in the results they produce (not a bad thing per se, but something we need to be aware of).

Contributor
I defer to you on that - we have been generally happy with larger default values of ftol_rel, triggering convergence on status 3 most of the time.
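
For concreteness, keeping both stopping rules while loosening the ELBO tolerance would look something like the sketch below (the helper name check_convergence and the example tolerances are illustrative, not code from this PR):

check_convergence <- function(delta_f, delta_x, config) {
  status <- 0
  if (delta_f < config$ftol_rel) status <- 3  # relative change in the ELBO
  if (delta_x < config$xtol_rel) status <- 4  # relative change in the parameters
  status
}

With, say, ftol_rel = 1e-6 and a tighter xtol_rel = 1e-8, status 3 would still trigger first most of the time, while the nlopt and torch implementations keep the same pair of safeguards.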

if (status %in% c(3,4)) {
objective <- objective[1:iterate + 1]
break
@@ -217,6 +220,54 @@ PLNfit <- R6Class(
invisible(list(var_B = var_B, var_Omega = var_Omega))
},

compute_vcov_from_resamples = function(resamples){
# compute the covariance of the parameters
get_cov_mat = function(data, cell_group) {
Collaborator
I may be mistaken, but get_cov_mat() appears to be defined but never used anywhere in compute_vcov_from_resamples(). Is it necessary?

Contributor
Vestigial code from debugging - can be removed (a trimmed sketch of the function without it follows below).


cov_matrix = cov(data)
rownames(cov_matrix) = paste0(cell_group, "_", rownames(cov_matrix))
colnames(cov_matrix) = paste0(cell_group, "_", colnames(cov_matrix))
return(cov_matrix)
}


B_list = resamples %>% map("B")
#print (B_list)
vcov_B = lapply(seq(1, ncol(private$B)), function(B_col){
param_ests_for_col = B_list %>% map(~.x[, B_col])
param_ests_for_col = do.call(rbind, param_ests_for_col)
print (param_ests_for_col)
row_vcov = cov(param_ests_for_col)
})
#print ("vcov blocks")
#print (vcov_B)

#B_vcov <- resamples %>% map("B") %>% map(~( . )) %>% reduce(cov)

#var_jack <- jacks %>% map("B") %>% map(~( (. - B_jack)^2)) %>% reduce(`+`) %>%
# `dimnames<-`(dimnames(private$B))
#B_hat <- private$B[,] ## strips attributes while preserving names

vcov_B = Matrix::bdiag(vcov_B) %>% as.matrix()

rownames(vcov_B) <- colnames(vcov_B) <-
expand.grid(covariates = rownames(private$B),
responses = colnames(private$B)) %>% rev() %>%
## Hack to make sure that species is first and varies slowest
apply(1, paste0, collapse = "_")

#print (pheatmap::pheatmap(vcov_B, cluster_rows=FALSE, cluster_cols=FALSE))


#names = lapply(bootstrapped_df$cov_mat, function(m){ colnames(m)}) %>% unlist()
#rownames(bootstrapped_vhat) = names
#colnames(bootstrapped_vhat) = names

vcov_B = methods::as(vcov_B, "dgCMatrix")

return(vcov_B)
},
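
As a point of reference, here is a trimmed sketch of what compute_vcov_from_resamples() reduces to once the unused get_cov_mat() helper and the debugging leftovers are dropped. It assumes purrr, magrittr, and Matrix are available as elsewhere in the package; the _trimmed suffix is only there to distinguish it from the committed code:

compute_vcov_from_resamples_trimmed <- function(resamples, B) {
  # One covariance block per column of B (i.e. per response/species),
  # estimated across the resampled coefficient matrices
  B_list <- resamples %>% map("B")
  vcov_blocks <- lapply(seq_len(ncol(B)), function(B_col) {
    param_ests_for_col <- do.call(rbind, B_list %>% map(~ .x[, B_col]))
    cov(param_ests_for_col)
  })
  # Assemble the block-diagonal vcov matrix and name it so that the
  # response varies slowest, as in the diff above
  vcov_B <- Matrix::bdiag(vcov_blocks) %>% as.matrix()
  rownames(vcov_B) <- colnames(vcov_B) <-
    expand.grid(covariates = rownames(B),
                responses  = colnames(B)) %>% rev() %>%
    apply(1, paste0, collapse = "_")
  methods::as(vcov_B, "dgCMatrix")
}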

variance_jackknife = function(Y, X, O, w, config = config_default_nlopt) {
jacks <- lapply(seq_len(self$n), function(i) {
Collaborator
I have no strong opinion on this one, as @jchiquet was the one who wrote this part, but is there a reason to prefer lapply to future_lapply (one less dependency? simpler to use?)? A nice thing about future_lapply is that it is backend-agnostic and can be used with several parallelization paradigms.

Contributor
Ah, yes, agree it would be wonderful to keep this. The issue is that on machines that use OpenBLAS with a multithreaded backend, using future can deadlock the session. A workaround is to wrap calls to future with something like this:

# Remember the current thread settings (defaulting to 1 if unset), then
# force single-threaded OpenMP/BLAS while the future workers run
old_omp_num_threads = as.numeric(Sys.getenv("OMP_NUM_THREADS"))
if (is.na(old_omp_num_threads)) {
  old_omp_num_threads = 1
}
RhpcBLASctl::omp_set_num_threads(1)

old_blas_num_threads = as.numeric(Sys.getenv("OPENBLAS_NUM_THREADS"))
if (is.na(old_blas_num_threads)) {
  old_blas_num_threads = 1
}
RhpcBLASctl::blas_set_num_threads(1)

Then you do your work with future, and afterwards restore the previous settings:

RhpcBLASctl::omp_set_num_threads(old_omp_num_threads)
RhpcBLASctl::blas_set_num_threads(old_blas_num_threads)

We didn't add this because we didn't want to add a new dependency on RhpcBLASctl to the package, but you could do so if you want to be able to do linear algebra inside functions called by future.
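
For illustration, the guard can be packaged as a small wrapper so each resampling method does not have to repeat it; the helper name with_single_threaded_blas and the use of future.apply below are assumptions for the sketch, not code from this PR:

with_single_threaded_blas <- function(expr) {
  # Remember the current settings, defaulting to 1 if the variables are unset
  old_omp  <- as.numeric(Sys.getenv("OMP_NUM_THREADS"))
  old_blas <- as.numeric(Sys.getenv("OPENBLAS_NUM_THREADS"))
  if (is.na(old_omp))  old_omp  <- 1
  if (is.na(old_blas)) old_blas <- 1
  # Pin OpenMP/BLAS to one thread while the wrapped expression runs
  RhpcBLASctl::omp_set_num_threads(1)
  RhpcBLASctl::blas_set_num_threads(1)
  on.exit({
    RhpcBLASctl::omp_set_num_threads(old_omp)
    RhpcBLASctl::blas_set_num_threads(old_blas)
  })
  force(expr)
}

# e.g., inside variance_jackknife():
# jacks <- with_single_threaded_blas(
#   future.apply::future_lapply(seq_len(self$n), function(i) { ... })
# )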

data <- list(Y = Y[-i, , drop = FALSE],
@@ -237,6 +288,9 @@ PLNfit <- R6Class(
attr(private$B, "bias") <- (self$n - 1) * (B_jack - B_hat)
attr(private$B, "variance_jackknife") <- (self$n - 1) / self$n * var_jack

vcov_jacks = private$compute_vcov_from_resamples(jacks)
attr(private$B, "vcov_jackknife") <- vcov_jacks

Omega_jack <- jacks %>% map("Omega") %>% reduce(`+`) / self$n
var_jack <- jacks %>% map("Omega") %>% map(~( (. - Omega_jack)^2)) %>% reduce(`+`) %>%
`dimnames<-`(dimnames(private$Omega))
@@ -275,6 +329,9 @@ PLNfit <- R6Class(
boots %>% map("B") %>% map(~( (. - B_boots)^2)) %>% reduce(`+`) %>%
`dimnames<-`(dimnames(private$B)) / n_resamples

vcov_boots = private$compute_vcov_from_resamples(boots)
attr(private$B, "vcov_bootstrap") <- vcov_boots

Omega_boots <- boots %>% map("Omega") %>% reduce(`+`) / n_resamples
attr(private$Omega, "variance_bootstrap") <-
boots %>% map("Omega") %>% map(~( (. - Omega_boots)^2)) %>% reduce(`+`) %>%