Skip to content

Commit

Permalink
implement process options
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk authored and seabbs committed Jan 27, 2023
1 parent 619a1a4 commit 9167cb4
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 7 deletions.
11 changes: 8 additions & 3 deletions R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
#' @param reported_cases A data frame of confirmed cases (confirm) by date
#' (date). confirm must be integer and date must be in date format.
#'
#' @param process_model A character string that defines what is being
#' modelled: "infections", "growth" or "R" (default). If ' set to "R",
#' a generation time distribution needs to be defined via the `generation_time`
#' argument.
#'
#' @param generation_time A call to `generation_time_opts()` defining the
#' generation time distribution used. For backwards compatibility a list of
#' summary parameters can also be passed.
Expand Down Expand Up @@ -218,7 +223,7 @@
#' options(old_opts)
#' }
estimate_infections <- function(reported_cases,
model = "R",
process_opts = process_opts(),
generation_time = generation_time_opts(),
delays = delay_opts(),
truncation = trunc_opts(),
Expand Down Expand Up @@ -288,6 +293,7 @@ estimate_infections <- function(reported_cases,
# Define stan model parameters
data <- create_stan_data(
reported_cases = reported_cases,
process_opts = process_opts,
generation_time = generation_time,
delays = delays,
truncation = truncation,
Expand All @@ -296,8 +302,7 @@ estimate_infections <- function(reported_cases,
obs = obs,
backcalc = backcalc,
shifted_cases = shifted_cases$confirm,
horizon = horizon,
process_model = process_model
horizon = horizon
)

# Set up default settings
Expand Down
94 changes: 90 additions & 4 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ generation_time_opts <- function(..., disease, source, max = 15L,
names(gt) <- paste0("gt_", names(gt))

return(gt)
}

#' Delay Distribution Options
#'
Expand Down Expand Up @@ -211,7 +210,7 @@ trunc_opts <- function(dist = list()) {

#' Time-Varying Reproduction Number Options
#'
#' @description `r lifecycle::badge("stable")`
#' @description `r lifecycle::badge("deprecated")`
#' Defines a list specifying the optional arguments for the time-varying
#' reproduction number. Custom settings can be supplied which override the
#' defaults.
Expand Down Expand Up @@ -249,6 +248,7 @@ trunc_opts <- function(dist = list()) {
#'
#' @return A list of settings defining the time-varying reproduction number.
#' @author Sam Abbott

#' @inheritParams create_future_rt
#' @export
#' @examples
Expand All @@ -267,6 +267,7 @@ rt_opts <- function(prior = list(mean = 1, sd = 1),
future = "latest",
gp_on = "R_t-1",
pop = 0) {
stop("rt_opts is deprecated - use process_opts instead")
rt <- list(
prior = prior,
use_rt = use_rt,
Expand All @@ -288,9 +289,93 @@ rt_opts <- function(prior = list(mean = 1, sd = 1),
return(rt)
}

#' Back Calculation Options
#' Process model optionss
#'
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the optional arguments for the process mode.
#' Custom settings can be supplied which override the defaults.
#' @param prior List containing named numeric elements "mean" and "sd". The mean and
#' standard deviation of the log normal Rt prior. Defaults to mean of 1 and standard
#' deviation of 1.
#' @param use_rt Logical, defaults to `TRUE`. Should Rt be used to generate infections
#' and hence reported cases.
#' @param rw Numeric step size of the random walk, defaults to 0. To specify a weekly random
#' walk set `rw = 7`. For more custom break point settings consider passing in a `breakpoints`
#' variable as outlined in the next section.
#' @param use_breakpoints Logical, defaults to `TRUE`. Should break points be used if present
#' as a `breakpoint` variable in the input data. Break points should be defined as 1 if present
#' and otherwise 0. By default breakpoints are fit jointly with a global non-parametric effect
#' and so represent a conservative estimate of break point changes (alter this by setting `gp = NULL`).
#' @param pop Integer, defaults to 0. Susceptible population initially present. Used to adjust
#' Rt estimates when otherwise fixed based on the proportion of the population that is
#' susceptible. When set to 0 no population adjustment is done.
#' @param gp_on Character string, defaulting to "R_t-1". Indicates how the Gaussian process,
#' if in use, should be applied to Rt. Currently supported options are applying the Gaussian
#' process to the last estimated Rt (i.e Rt = Rt-1 * GP), and applying the Gaussian process to
#' a global mean (i.e Rt = R0 * GP). Both should produced comparable results when data is not
#' sparse but the method relying on a global mean will revert to this for real time estimates,
#' which may not be desirable.
#' @return A list of settings defining the time-varying reproduction number
#' @inheritParams create_future_rt
#' @export
#' @examples
#' # default settings
#' rt_opts()
#'
#' # add a custom length scale
#' rt_opts(prior = list(mean = 2, sd = 1))
#'
#' # add a weekly random walk
#' rt_opts(rw = 7)
#' @importFrom data.table fcase
process_opts <- function(model = "R",
prior_mean = data.table::fcase(
model == "R", list(mean = 1, sd = 1),
model == "growth", list(mean = 0, sd = 1),
model == "infections", NULL
),
prior_t = NULL,
rw = 0,
use_breakpoints = TRUE,
future = "latest",
stationary = FALSE
pop = 0) {

## check
model_choices <- c("infections", "growth", "R")
process_model <- match.arg(process_model, choices = model_choices)
process_model <- which(process_model == model_choices) - 1

if (!(xor(is.null(prior_mean), is.null(prior_t)))) {
stop("Either 'prior_mean' or 'prior_t' must be set to NULL")
}
process <- list(
process_model = process_model,
prior_mean = prior_mean,
prior_t = prior_t,
rw = rw,
use_breakpoints = use_breakpoints,
future = future,
stationary = stationary,
pop = pop,
)

# replace default settings with those specified by user
if (process$rw > 0) {
process$use_breakpoints <- TRUE
}

if (!is.null(prior_mean) &&
!("mean" %in% names(process$prior) &&
"sd" %in% names(process$prior))) {
stop("prior must have both a mean and sd specified")
}
return(process)
}

#' Back Calculation Options
#'
#' @description `r lifecycle::badge("deprecated")`
#' Defines a list specifying the optional arguments for the back calculation
#' of cases. Only used if `rt = NULL`.
#'
Expand Down Expand Up @@ -323,7 +408,8 @@ rt_opts <- function(prior = list(mean = 1, sd = 1),
#' # default settings
#' backcalc_opts()
backcalc_opts <- function(prior = "reports", prior_window = 14, rt_window = 1) {
backcalc <- list(
stop("backcalc_opts is deprecated - use process_opts instead")
backcalc <- list(
prior = match.arg(prior, choices = c("reports", "none", "infections")),
prior_window = prior_window,
rt_window = as.integer(rt_window)
Expand Down

0 comments on commit 9167cb4

Please sign in to comment.