Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes from Cole Trapnell #110

Merged
merged 14 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion R/PLN.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ PLN <- function(formula, data, subset, weights, control = PLN_param()) {

## post-treatment
if (control$trace > 0) cat("\n Post-treatments...")
myPLN$postTreatment(args$Y, args$X, args$O, args$w, control$config_post)
myPLN$postTreatment(args$Y, args$X, args$O, args$w, control$config_post, control$config_optim)

if (control$trace > 0) cat("\n DONE!\n")
myPLN
Expand Down
2 changes: 1 addition & 1 deletion R/PLNLDA.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ PLNLDA <- function(formula, data, subset, weights, grouping, control = PLN_param
myLDA$optimize(grouping, args$Y, args$X, args$O, args$w, control$config_optim)

## Post-treatment: prepare LDA visualization
myLDA$postTreatment(grouping, args$Y, args$X, args$O, control$config_post)
myLDA$postTreatment(grouping, args$Y, args$X, args$O, control$config_post, control$config_optim)

if (control$trace > 0) cat("\n DONE!\n")
myLDA
Expand Down
4 changes: 2 additions & 2 deletions R/PLNLDAfit-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ PLNLDAfit <- R6Class(
## Post treatment --------------------
#' @description Update R2, fisher and std_err fields and visualization
#' @param config list controlling the post-treatment
postTreatment = function(grouping, responses, covariates, offsets, config) {
postTreatment = function(grouping, responses, covariates, offsets, config_post, config_optim) {
covariates <- cbind(covariates, model.matrix( ~ grouping + 0))
super$postTreatment(responses, covariates, offsets, config = config)
super$postTreatment(responses, covariates, offsets, config_post = config_post, config_optim = config_optim)
rownames(private$C) <- colnames(private$C) <- colnames(responses)
colnames(private$S) <- 1:self$q
if (config$trace > 1) cat("\n\tCompute LD scores for visualization...")
Expand Down
2 changes: 1 addition & 1 deletion R/PLNPCA.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ PLNPCA <- function(formula, data, subset, weights, ranks = 1:5, control = PLNPCA
## Post-treatments: pseudo-R2, rearrange criteria and prepare PCA visualization
if (control$trace > 0) cat("\n Post-treatments")
config_post <- config_post_default_PLNPCA; config_post$trace <- control$trace
myPCA$postTreatment(config_post)
myPCA$postTreatment(config_post, control$config_optim)

if (control$trace > 0) cat("\n DONE!\n")
myPCA
Expand Down
4 changes: 2 additions & 2 deletions R/PLNPCAfit-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ PLNPCAfit <- R6Class(
#' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE.
#' * rsquared boolean indicating whether approximation of R2 based on deviance should be computed. Default is TRUE
#' * trace integer for verbosity. should be > 1 to see output in post-treatments
postTreatment = function(responses, covariates, offsets, weights, config, nullModel) {
super$postTreatment(responses, covariates, offsets, weights, config, nullModel)
postTreatment = function(responses, covariates, offsets, weights, config_post, config_optim, nullModel) {
super$postTreatment(responses, covariates, offsets, weights, config_post, config_optim, nullModel)
colnames(private$C) <- colnames(private$M) <- 1:self$q
rownames(private$C) <- colnames(responses)
self$setVisualization()
Expand Down
11 changes: 6 additions & 5 deletions R/PLNfamily-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,18 @@ PLNfamily <-
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
## Post treatment --------------------
#' @description Update fields after optimization
#' @param config a list for controlling the post-treatment.
postTreatment = function(config) {
nullModel <- nullModelPoisson(self$responses, self$covariates, self$offsets, self$weights)
#' @param config_post a list for controlling the post-treatment.
postTreatment = function(config_post, config_optim) {
#nullModel <- nullModelPoisson(self$responses, self$covariates, self$offsets, self$weights)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The general post-treatment of PLN families no longer compares against a null Poisson model. Is this to improve speed, or because the null model is not required for the post-treatments?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In our applications, the call to glm.fit inside nullModelPoisson() can sometimes throw an exception, which ruins a perfectly good PLN fit! Since it seemed to us that nullModelPoisson() is only needed when (optionally) computing the approximate R2, we thought this call was superfluous.

for (model in self$models)
model$postTreatment(
self$responses,
self$covariates,
self$offsets,
self$weights,
config,
nullModel = nullModel
config_post=config_post,
config_optim=config_optim,
nullModel = NULL
)
},

Expand Down
37 changes: 20 additions & 17 deletions R/PLNfit-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ PLNfit <- R6Class(
## PRIVATE METHODS FOR VARIANCE OF THE ESTIMATORS
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

variance_variational = function(X) {
variance_variational = function(X, config = config_default_nlopt) {
## Variance of B for n data points
fisher <- Matrix::bdiag(lapply(1:self$p, function(j) {
crossprod(X, private$A[, j] * X) # t(X) %*% diag(A[, i]) %*% X
Expand Down Expand Up @@ -375,7 +375,7 @@ PLNfit <- R6Class(
#' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE.
#' * rsquared boolean indicating whether approximation of R2 based on deviance should be computed. Default is TRUE
#' * trace integer for verbosity. should be > 1 to see output in post-treatments
postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config, nullModel = NULL) {
postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config_post, config_optim, nullModel = NULL) {
## PARAMATERS DIMNAMES
## Set names according to those of the data matrices. If missing, use sensible defaults
if (is.null(colnames(responses)))
Expand All @@ -392,24 +392,27 @@ PLNfit <- R6Class(

## OPTIONAL POST-TREATMENT (potentially costly)
## 1. compute and store approximated R2 with Poisson-based deviance
if (config$rsquared) {
if(config$trace > 1) cat("\n\tComputing bootstrap estimator of the variance...")
if (config_post$rsquared) {
if(config_post$trace > 1) cat("\n\tComputing approximate R^2...")
private$approx_r2(responses, covariates, offsets, weights, nullModel)
}
## 2. compute and store matrix of standard variances for B and Omega with rough variational approximation
if (config$variational_var) {
if(config$trace > 1) cat("\n\tComputing variational estimator of the variance...")
private$variance_variational(covariates)
if (config_post$variational_var) {
if(config_post$trace > 1) cat("\n\tComputing variational estimator of the variance...")
private$variance_variational(covariates, config = config_optim)
}
## 3. Jackknife estimation of bias and variance
if (config$jackknife) {
if(config$trace > 1) cat("\n\tComputing jackknife estimator of the variance...")
private$variance_jackknife(responses, covariates, offsets, weights)
if (config_post$jackknife) {
if(config_post$trace > 1) cat("\n\tComputing jackknife estimator of the variance...")
private$variance_jackknife(responses, covariates, offsets, weights, config = config_optim)
}
## 4. Bootstrap estimation of variance
if (config$bootstrap > 0) {
if(config$trace > 1) cat("\n\tComputing bootstrap estimator of the variance...")
private$variance_bootstrap(responses, covariates, offsets, weights, config$bootstrap)
if (config_post$bootstrap > 0) {
if(config_post$trace > 1) {
cat("\n\tComputing bootstrap estimator of the variance...")
print (str(config_optim))
}
private$variance_bootstrap(responses, covariates, offsets, weights, n_resamples=config_post$bootstrap, config = config_optim)
}
},

Expand Down Expand Up @@ -804,11 +807,11 @@ PLNfit_fixedcov <- R6Class(
#' * bootstrap integer indicating the number of bootstrap resamples generated to evaluate the variance of the model parameters. Default is 0 (inactivated).
#' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE.
#' * rsquared boolean indicating whether approximation of R2 based on deviance should be computed. Default is TRUE
postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config, nullModel = NULL) {
super$postTreatment(responses, covariates, offsets, weights, config, nullModel)
postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config_post, config_optim, nullModel = NULL) {
super$postTreatment(responses, covariates, offsets, weights, config_post, config_optim, nullModel)
## 6. compute and store matrix of standard variances for B with sandwich correction approximation
if (config$sandwich_var) {
if(config$trace > 1) cat("\n\tComputing sandwich estimator of the variance...")
if (config_post$sandwich_var) {
if(config_post$trace > 1) cat("\n\tComputing sandwich estimator of the variance...")
private$vcov_sandwich_B(responses, covariates)
}
}
Expand Down
2 changes: 1 addition & 1 deletion R/PLNmixture.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ PLNmixture <- function(formula, data, subset, clusters = 1:5, control = PLNmixt
## Post-treatments: Compute pseudo-R2, rearrange criteria and the visualization for PCA
if (control$trace > 0) cat("\n Post-treatments")
config_post <- config_post_default_PLNmixture; config_post$trace <- control$trace
myPLN$postTreatment(config_post)
myPLN$postTreatment(config_post, control$config_optim)

if (control$trace > 0) cat("\n DONE!\n")
myPLN
Expand Down
5 changes: 3 additions & 2 deletions R/PLNmixturefit-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ PLNmixturefit <-
## Post treatment --------------------
#' @description Update fields after optimization
#' @param config a list for controlling the post-treatment
postTreatment = function(responses, covariates, offsets, weights, config, nullModel) {
postTreatment = function(responses, covariates, offsets, weights, config_post, config_optim, nullModel) {

## restoring the full design matrix (group means + covariates)
mu_k <- matrix(1, self$n, ncol = 1); colnames(mu_k) <- 'Intercept'
Expand All @@ -292,7 +292,8 @@ PLNmixturefit <-
mu_k,
offsets,
private$tau[,k_],
config,
config_post,
config_optim,
nullModel = nullModel
)
},
Expand Down
14 changes: 11 additions & 3 deletions R/PLNnetwork.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ PLNnetwork <- function(formula, data, subset, weights, penalties = NULL, control

## Post-treatments
if (control$trace > 0) cat("\n Post-treatments")
config_post <- config_post_default_PLNnetwork; config_post$trace <- control$trace
myPLN$postTreatment(config_post)
#config_post <- config_post_default_PLNnetwork;
#config_post$trace <- control$trace
myPLN$postTreatment(control$config_post, control$config_optim)

if (control$trace > 0) cat("\n DONE!\n")
myPLN
Expand Down Expand Up @@ -85,18 +86,24 @@ PLNnetwork <- function(formula, data, subset, weights, penalties = NULL, control
#'
#' @export
PLNnetwork_param <- function(
backend = "nlopt",
backend = c("nlopt", "torch"),
trace = 1 ,
n_penalties = 30 ,
min_ratio = 0.1 ,
penalize_diagonal = TRUE ,
penalty_weights = NULL ,
config_post = list(),
config_optim = list(),
inception = NULL
) {

if (!is.null(inception)) stopifnot(isPLNfit(inception))

## post-treatment config
config_pst <- config_post_default_PLN
config_pst[names(config_post)] <- config_post
config_pst$trace <- trace

## optimization config
backend <- match.arg(backend)
stopifnot(backend %in% c("nlopt", "torch"))
Expand All @@ -123,6 +130,7 @@ PLNnetwork_param <- function(
jackknife = FALSE ,
bootstrap = 0 ,
variance = TRUE ,
config_post = config_pst ,
config_optim = config_opt ,
inception = inception ), class = "PLNmodels_param")
}
2 changes: 1 addition & 1 deletion tests/testthat/test-standard-error.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ test_that("Check that variance estimation are coherent in PLNfit", {
trace = 2
)

myPLN$postTreatment(Y, X, exp(log_O), config = config_post)
myPLN$postTreatment(Y, X, exp(log_O), config_post = config_post)

tr_variational <- sum(standard_error(myPLN, "variational")^2)
tr_bootstrap <- sum(standard_error(myPLN, "bootstrap")^2)
Expand Down