Changes from Cole Trapnell #110
@@ -80,21 +80,27 @@ PLNfit <- R6Class(

  torch_vloglik = function(data, params) {
    S2 <- torch_square(params$S)
-   Ji <- .5 * self$p - rowSums(.logfactorial(as.matrix(data$Y))) + as.numeric(
-     .5 * torch_logdet(params$Omega) +
-       torch_sum(data$Y * params$Z - params$A + .5 * torch_log(S2), dim = 2) -
-       .5 * torch_sum(torch_mm(params$M, params$Omega) * params$M + S2 * torch_diag(params$Omega), dim = 2)
-   )
-   attr(Ji, "weights") <- as.numeric(data$w)
+   Ji_tmp = .5 * torch_logdet(params$Omega) +
+     torch_sum(data$Y * params$Z - params$A + .5 * torch_log(S2), dim = 2) -
+     .5 * torch_sum(torch_mm(params$M, params$Omega) * params$M + S2 * torch_diag(params$Omega), dim = 2)
+   Ji_tmp = Ji_tmp$cpu()
+   Ji_tmp = as.numeric(Ji_tmp)
+   Ji <- .5 * self$p - rowSums(.logfactorial(as.matrix(data$Y$cpu()))) + Ji_tmp

Review comment: Why not use […]

Reply: Good question. Perhaps it would be better to defer that so we can use .logfactorial_torch() instead.

+   attr(Ji, "weights") <- as.numeric(data$w$cpu())
    Ji
  },
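As a hedged sketch of the direction the comment above points to (not part of this PR): once .logfactorial_torch() from utils.R is available, the log-factorial term could stay on the torch device so that only one device-to-host transfer happens at the end. The $clone() below is an added assumption, to keep the helper's in-place `n[n == 0] <- 1` from altering data$Y; self, private and params come from the PLNfit R6 class as in the PR.

```r
## Hypothetical refactor sketch, NOT the PR's code: keep everything on the
## torch device and transfer to the CPU once at the end.
torch_vloglik = function(data, params) {
  S2 <- torch_square(params$S)
  Ji_dev <- .5 * torch_logdet(params$Omega) +
    torch_sum(data$Y * params$Z - params$A + .5 * torch_log(S2) -
              .logfactorial_torch(data$Y$clone()), dim = 2) -   # clone so Y is not modified in place
    .5 * torch_sum(torch_mm(params$M, params$Omega) * params$M + S2 * torch_diag(params$Omega), dim = 2)
  Ji <- .5 * self$p + as.numeric(Ji_dev$cpu())   # single device-to-host transfer
  attr(Ji, "weights") <- as.numeric(data$w$cpu())
  Ji
}
```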
  #' @import torch
  torch_optimize = function(data, params, config) {

+   #config$device = "mps"
+   if (config$trace > 1)
+     message (paste("optimizing with device: ", config$device))
    ## Conversion of data and parameters to torch tensors (pointers)
-   data <- lapply(data, torch_tensor) # list with Y, X, O, w
-   params <- lapply(params, torch_tensor, requires_grad = TRUE) # list with B, M, S
+   data <- lapply(data, torch_tensor, dtype = torch_float32(), device = config$device) # list with Y, X, O, w
+   params <- lapply(params, torch_tensor, dtype = torch_float32(), requires_grad = TRUE, device = config$device) # list with B, M, S

    ## Initialize optimizer
    optimizer <- switch(config$algorithm,
@@ -111,11 +117,14 @@ PLNfit <- R6Class(
    batch_size <- floor(self$n/num_batch)

    objective <- double(length = config$num_epoch + 1)
+   #B_old = optimizer$param_groups[[1]]$params$B$clone()
    for (iterate in 1:num_epoch) {
-     B_old <- as.numeric(optimizer$param_groups[[1]]$params$B)
+     #B_old <- as.numeric(optimizer$param_groups[[1]]$params$B)
+     B_old = optimizer$param_groups[[1]]$params$B$clone()
      # rearrange the data each epoch
-     permute <- torch::torch_randperm(self$n) + 1L
+     #permute <- torch::torch_randperm(self$n, device = "cpu") + 1L
+     permute = torch::torch_tensor(sample.int(self$n), dtype = torch_long(), device=config$device)

      for (batch_idx in 1:num_batch) {
        # here index is a vector of the indices in the batch
        index <- permute[(batch_size*(batch_idx - 1) + 1):(batch_idx*batch_size)]
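For intuition about the mini-batch indexing above, here is a small standalone illustration (the values are made up, not from the PR): with n = 10 and num_batch = 3, batch_size = floor(10/3) = 3, so the three batches cover nine of the ten permuted indices and the remaining one is skipped for that epoch.

```r
## Toy illustration of the mini-batch index arithmetic (made-up values).
n <- 10; num_batch <- 3
batch_size <- floor(n / num_batch)            # 3
permute <- sample.int(n)                      # a random permutation of 1:10
batches <- lapply(1:num_batch, function(batch_idx)
  permute[(batch_size * (batch_idx - 1) + 1):(batch_idx * batch_size)])
batches                                       # three batches of 3; one index unused this epoch
```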
@@ -129,14 +138,21 @@ PLNfit <- R6Class(

      ## assess convergence
      objective[iterate + 1] <- loss$item()
-     B_new <- as.numeric(optimizer$param_groups[[1]]$params$B)
+     B_new <- optimizer$param_groups[[1]]$params$B
      delta_f <- abs(objective[iterate] - objective[iterate + 1]) / abs(objective[iterate + 1])
-     delta_x <- sum(abs(B_old - B_new))/sum(abs(B_new))
+     delta_x <- torch::torch_sum(torch::torch_abs(B_old - B_new))/torch::torch_sum(torch::torch_abs(B_new))

+     #print (delta_f)
+     #print (delta_x)
+     delta_x = delta_x$cpu()
+     #print (delta_x)
+     delta_x = as.matrix(delta_x)
+     #print (delta_x)

      ## display progress
      if (config$trace > 1 && (iterate %% 50 == 0))
        cat('\niteration: ', iterate, 'objective', objective[iterate + 1],
            'delta_f' , round(delta_f, 6), 'delta_x', round(delta_x, 6))

      ## Check for convergence
      if (delta_f < config$ftol_rel) status <- 3
@@ -152,7 +168,10 @@ PLNfit <- R6Class(
      params$Z <- data$O + params$M + torch_matmul(data$X, params$B)
      params$A <- torch_exp(params$Z + torch_pow(params$S, 2)/2)

-     out <- lapply(params, as.matrix)
+     out <- lapply(params, function(x) {
+       x = x$cpu()
+       as.matrix(x)}
+     )
      out$Ji <- private$torch_vloglik(data, params)
      out$monitoring <- list(
        objective = objective,
@@ -199,7 +218,7 @@ PLNfit <- R6Class(
      },

      variance_jackknife = function(Y, X, O, w, config = config_default_nlopt) {
-       jacks <- future.apply::future_lapply(seq_len(self$n), function(i) {
+       jacks <- lapply(seq_len(self$n), function(i) {

Review comment: I have no strong opinion on this one as @jchiquet was the one who wrote this part, but is there a reason to prefer lapply over future_lapply (one less dependency? simpler to use?)? A nice thing about future_lapply is that it is backend-agnostic and can be used for several parallelization paradigms.

Reply: Ah, yes, agree it would be wonderful to keep this. The issue is that on machines that use OpenBLAS with a multithreaded backend, using future can deadlock the session. A workaround is to wrap calls to future with something like the sketch below: pin BLAS to a single thread, do the work with future, and then restore the thread count. We didn't add this because we didn't want to add a new dependency on RhpcBLASctl to the package, but you could do so if you want to be able to do linear algebra inside functions called by future.
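The two code snippets the reply refers to did not survive extraction; the following is a hedged reconstruction of the kind of wrapper being described, assuming the RhpcBLASctl package (blas_get_num_procs() / blas_set_num_threads()):

```r
## Hedged reconstruction of the workaround described above (not code from the PR):
## pin OpenBLAS to a single thread before spawning future workers, then restore.
library(RhpcBLASctl)
library(future.apply)

n_blas <- blas_get_num_procs()    # remember the current BLAS thread count
blas_set_num_threads(1)           # avoid OpenBLAS/future deadlocks

res <- future.apply::future_lapply(seq_len(10), function(i) {
  ## ... per-resample work that may call BLAS-backed linear algebra ...
  i
}, future.seed = TRUE)

blas_set_num_threads(n_blas)      # restore the multithreaded BLAS backend
```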
          data <- list(Y = Y[-i, , drop = FALSE],
                       X = X[-i, , drop = FALSE],
                       O = O[-i, , drop = FALSE],
@@ -209,7 +228,7 @@ PLNfit <- R6Class(
                       config = config)
          optim_out <- do.call(private$optimizer$main, args)
          optim_out[c("B", "Omega")]
-       }, future.seed = TRUE)
+       })

        B_jack <- jacks %>% map("B") %>% reduce(`+`) / self$n
        var_jack <- jacks %>% map("B") %>% map(~( (. - B_jack)^2)) %>% reduce(`+`) %>%
@@ -228,17 +247,28 @@ PLNfit <- R6Class(

      variance_bootstrap = function(Y, X, O, w, n_resamples = 100, config = config_default_nlopt) {
        resamples <- replicate(n_resamples, sample.int(self$n, replace = TRUE), simplify = FALSE)
-       boots <- future.apply::future_lapply(resamples, function(resample) {
+       boots <- lapply(resamples, function(resample) {
          data <- list(Y = Y[resample, , drop = FALSE],
                       X = X[resample, , drop = FALSE],
                       O = O[resample, , drop = FALSE],
                       w = w[resample])
+         #print (config$torch_device)
+         #print (config)
+         if (config$algorithm %in% c("RPROP", "RMSPROP", "ADAM", "ADAGRAD")) # hack, to know if we're doing torch or not

Review comment: Should we add a […]

Reply: That would be better. Would also be good to be able to specify the torch device (e.g. "mps", "cuda", etc).

+           data <- lapply(data, torch_tensor, device = config$device) # list with Y, X, O, w

+         #print (data$Y$device)

          args <- list(data = data,
                       params = list(B = private$B, M = matrix(0,self$n,self$p), S = private$S[resample, ]),
                       config = config)
+         if (config$algorithm %in% c("RPROP", "RMSPROP", "ADAM", "ADAGRAD")) # hack, to know if we're doing torch or not

Review comment: Same as previous comment.

+           args$params <- lapply(args$params, torch_tensor, requires_grad = TRUE, device = config$device) # list with B, M, S

          optim_out <- do.call(private$optimizer$main, args)
+         #print (optim_out)
          optim_out[c("B", "Omega", "monitoring")]
-       }, future.seed = TRUE)
+       })

        B_boots <- boots %>% map("B") %>% reduce(`+`) / n_resamples
        attr(private$B, "variance_bootstrap") <-
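A hedged sketch of the direction suggested in the two review comments above (the `backend` field and the toy data are hypothetical, not part of the PR or of the package): an explicit backend flag plus a caller-chosen device, instead of inferring torch from the optimizer name.

```r
library(torch)

## Hypothetical, illustrative config: make the backend explicit rather than
## matching on the optimizer name, and let the caller pick the device.
config <- list(algorithm = "RPROP", backend = "torch", device = "cpu")  # or "cuda" / "mps"

## Toy stand-in for the Y/X/O/w list built inside variance_bootstrap()
data <- list(Y = matrix(rpois(20, 5), 10, 2), X = matrix(1, 10, 1))

if (identical(config$backend, "torch"))
  data <- lapply(data, torch_tensor, device = config$device)
```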
@@ -26,7 +26,8 @@ config_default_torch <-
    step_sizes = c(1e-3, 50),
    etas = c(0.5, 1.2),
    centered = FALSE,
-   trace = 1
+   trace = 1,
+   device = "cpu"
  )

config_post_default_PLN <-
@@ -107,6 +108,11 @@ trace <- function(x) sum(diag(x))
    x
  }

+ .logfactorial_torch <- function(n){
+   n[n == 0] <- 1 ## 0! = 1!
+   n*torch_log(n) - n + torch_log(8*torch_pow(n,3) + 4*torch_pow(n,2) + n + 1/30)/6 + log(pi)/2

Review comment: For .logfactorial_torch(), shouldn't the final term log(pi)/2 be torch_log(pi)/2?

Reply: Yes, though not actually sure what would happen under the hood here in terms of evaluation?

+ }
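A small, hedged illustration of the point being discussed (not a ruling on which form the package should use): `pi` is a plain R double, so log(pi)/2 is evaluated once on the host and then broadcast when added to the tensor, whereas wrapping it in a tensor keeps it as a (CPU) tensor constant.

```r
library(torch)

x <- torch_tensor(c(1, 2, 3))

## log(pi)/2 is an ordinary R scalar; it broadcasts over the tensor.
x + log(pi)/2

## Building the constant as a tensor gives the same values; note the constant
## lives on the CPU, so on a "cuda"/"mps" tensor it would need the same device.
x + torch_log(torch_tensor(pi))/2
```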
  .logfactorial <- function(n) { # Ramanujan's formula
    n[n == 0] <- 1 ## 0! = 1!
    n*log(n) - n + log(8*n^3 + 4*n^2 + n + 1/30)/6 + log(pi)/2
Review comment: Why the additional comma if it's the last element of the list?

Reply: oops :)