From 8fe772a4a9cdebce85bbdd0e99b058aee03eecc1 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Fri, 18 Nov 2022 21:11:55 +0000 Subject: [PATCH] check fixes --- R/opts.R | 5 +- inst/stan/data/covariates.stan | 3 +- inst/stan/data/simulation_rt.stan | 2 +- inst/stan/estimate_infections.stan | 74 ++++++------------- .../functions/{rt.stan => covariates.stan} | 10 +-- inst/stan/functions/generated_quantities.stan | 22 +----- inst/stan/functions/infections.stan | 2 +- inst/stan/simulate_infections.stan | 25 +++---- 8 files changed, 48 insertions(+), 95 deletions(-) rename inst/stan/functions/{rt.stan => covariates.stan} (84%) diff --git a/R/opts.R b/R/opts.R index f7af12f72..10015ab11 100644 --- a/R/opts.R +++ b/R/opts.R @@ -90,6 +90,7 @@ generation_time_opts <- function(..., disease, source, max = 15L, names(gt) <- paste0("gt_", names(gt)) return(gt) +} #' Delay Distribution Options #' @@ -338,7 +339,7 @@ process_opts <- function(model = "R", rw = 0, use_breakpoints = TRUE, future = "latest", - stationary = FALSE + stationary = FALSE, pop = 0) { ## check @@ -357,7 +358,7 @@ process_opts <- function(model = "R", use_breakpoints = use_breakpoints, future = future, stationary = stationary, - pop = pop, + pop = pop ) # replace default settings with those specified by user diff --git a/inst/stan/data/covariates.stan b/inst/stan/data/covariates.stan index b216dd9b7..56f0b1a1a 100644 --- a/inst/stan/data/covariates.stan +++ b/inst/stan/data/covariates.stan @@ -1,6 +1,7 @@ int process_model; // 0 = infections; 1 = growth; 2 = rt int bp_n; // no of breakpoints (0 = no breakpoints) int breakpoints[t - seeding_time]; // when do breakpoints occur +int cov_mean_const; // 0 = not const mean; 1 = const mean real cov_mean_mean[cov_mean_const]; // const covariate mean real cov_mean_sd[cov_mean_const]; // const covariate sd -real cov_t[cov_mean_const ? 0 : t] // time-varying covariate mean +vector[cov_mean_const ? 0 : t] cov_t; // time-varying covariate mean diff --git a/inst/stan/data/simulation_rt.stan b/inst/stan/data/simulation_rt.stan index 9fcfeb06b..3a9f7931e 100644 --- a/inst/stan/data/simulation_rt.stan +++ b/inst/stan/data/simulation_rt.stan @@ -4,5 +4,5 @@ real gt_sd[n, 1]; // sd of generation time int gt_max[1]; // maximum generation time int gt_dist[1]; // 0 = lognormal; 1 = gamma - matrix[n, t - seeding_time] R; // reproduction number + vector[n] R[t - seeding_time]; // reproduction number int pop; // susceptible population diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index b296917eb..428feae5f 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -28,8 +28,8 @@ transformed data{ int noise_terms = setup_noise(ot_h, t, horizon, estimate_r, stationary, future_fixed, fixed_from); matrix[noise_terms, M] PHI = setup_gp(M, L, noise_terms); // basis function // covariate mean - real cov_mean_logmean[cov_mean_const] = log(cov_mean^2 / sqrt(cov_sd^2 + cov_mean^2)); - real cov_mean_logsd[cov_mean_const] = sqrt(log(1 + (cov_sd^2 / cov_mean^2))); + real cov_mean_logmean[cov_mean_const]; + real cov_mean_logsd[cov_mean_const]; int delay_max_fixed = (n_fixed_delays == 0 ? 0 : sum(delay_max[fixed_delays]) - num_elements(fixed_delays) + 1); @@ -39,6 +39,11 @@ transformed data{ vector[truncation && trunc_fixed[1] ? trunc_max[1] : 0] trunc_fixed_pmf; vector[delay_max_fixed] delays_fixed_pmf; + if (cov_mean_const) { + cov_mean_logmean[1] = log(cov_mean_mean[1]^2 / sqrt(cov_mean_sd[1]^2 + cov_mean_mean[1]^2)); + cov_mean_logsd[1] = sqrt(log(1 + (cov_mean_sd[1]^2 / cov_mean_mean[1]^2))); + } + if (gt_fixed[1]) { gt_fixed_pmf = discretised_pmf(gt_mean_mean[1], gt_sd_mean[1], gt_max[1], gt_dist[1], 1); } @@ -90,8 +95,10 @@ transformed parameters { noise = update_gp(PHI, M, L, alpha[1], rho[1], eta, gp_type); } // update covariates - cov = update_covariate(log_cov_mean, cov_t, noise, breakpoints, bp_effects, - stationary, ot_h, 0); + cov = update_covariate( + log_cov_mean, cov_t, noise, breakpoints, bp_effects, + stationary, ot_h + ); uobs_inf = generate_seed(initial_infections, initial_growth, seeding_time); // Estimate latent infections if (process_model == 0) { @@ -157,20 +164,18 @@ model { trunc_dist, 1 ); if (estimate_r) { - // priors on Rt - rt_lp( - log_R, initial_infections, initial_growth, bp_effects, bp_sd, bp_n, - seeding_time, r_logmean, r_logsd, prior_infections, prior_growth - ); // penalised_prior on generation interval delays_lp( gt_mean, gt_mean_mean, gt_mean_sd, gt_sd, gt_sd_mean, gt_sd_sd, gt_dist, gt_weight ); } - // priors on Rt - covariate_lp(log_cov_mean, bp_effects, bp_sd, bp_n, cov_mean_logmean, cov_mean_logsd); - infections_lp(initial_infections, initial_growth, prior_infections, prior_growth, - seeding_time); + // priors on covariates and infections + covariate_lp( + log_cov_mean, bp_effects, bp_sd, bp_n, cov_mean_logmean, cov_mean_logsd + ); + infections_lp( + initial_infections, initial_growth, prior_infections, prior_growth, seeding_time + ); // prior observation scaling if (obs_scale) { frac_obs[1] ~ normal(obs_scale_mean, obs_scale_sd) T[0, 1]; @@ -185,33 +190,11 @@ model { generated quantities { int imputed_reports[ot_h]; -<<<<<<< HEAD -<<<<<<< HEAD - vector[estimate_r > 0 ? 0: ot_h] gen_R; - real r[ot_h] - 1; - vector[return_likelihood ? ot : 0] log_lik; - if (estimate_r == 0){ -======= vector[estimate_r > 0 ? 0: ot_h] R; - real r[ot_h]; - vector[return_likelihood > 1 ? ot : 0] log_lik; - if (estimate_r){ - // estimate growth from estimated Rt - real set_gt_mean = (gt_mean_sd[1] > 0 ? gt_mean[1] : gt_mean_mean[1]); - real set_gt_sd = (gt_sd_sd [1]> 0 ? gt_sd[1] : gt_sd_mean[1]); - vector[gt_max[1]] gt_pmf = combine_pmfs(gt_fixed_pmf, gt_mean, gt_sd, gt_max, gt_dist, gt_max[1], 1); - } else { ->>>>>>> 7bc2510b (implement different model types) -======= - vector[estimate_r > 0 ? ot_h : 0] R; - real r[ot_h]; - vector[return_likelihood > 1 ? ot : 0] log_lik; -<<<<<<< HEAD - if (estimate_r) { ->>>>>>> fe3d94be (generate R if estimate_r > 0) -======= + real r[ot_h - 1]; + vector[return_likelihood ? ot : 0] log_lik; if (estimate_r == 0 && process_model != 2) { ->>>>>>> b520805f (update stan code to reflect model update) + // sample generation time real gt_mean_sample[1]; real gt_sd_sample[1]; @@ -225,19 +208,7 @@ generated quantities { ); // calculate Rt using infections and generation time -<<<<<<< HEAD -<<<<<<< HEAD - gen_R = calculate_Rt( - infections, seeding_time, gt_rev_pmf, rt_half_window -======= - // estimate growth from calculated Rt - R = calculate_Rt( - infections, seeding_time, gt_mean_sample, gt_sd_sample, - max_gt, rt_half_window ->>>>>>> 7bc2510b (implement different model types) - ); -======= - R = calculate_Rt(infections, seeding_time, gt_pmf); + R = calculate_Rt(infections, seeding_time, gt_rev_pmf); } else { R = cov; } @@ -245,7 +216,6 @@ generated quantities { r = calculate_growth(infections, seeding_time); } else { r = cov; ->>>>>>> b520805f (update stan code to reflect model update) } // simulate reported cases imputed_reports = report_rng(reports, rep_phi, obs_dist); diff --git a/inst/stan/functions/rt.stan b/inst/stan/functions/covariates.stan similarity index 84% rename from inst/stan/functions/rt.stan rename to inst/stan/functions/covariates.stan index 06cebad19..4f78e1c33 100644 --- a/inst/stan/functions/rt.stan +++ b/inst/stan/functions/covariates.stan @@ -1,5 +1,5 @@ // update combined covariates -vector update_covariate(vector log_cov_mean, vector cov_t, +vector update_covariate(real[] log_cov_mean, vector cov_t, vector noise, int[] bps, real[] bp_effects, int stationary, int t) { @@ -10,7 +10,7 @@ vector update_covariate(vector log_cov_mean, vector cov_t, // define result vectors vector[t] bp = rep_vector(0, t); vector[t] gp = rep_vector(0, t); - vector[t] R; + vector[t] cov; // initialise breakpoints if (bp_n) { for (s in 1:t) { @@ -35,7 +35,7 @@ vector update_covariate(vector log_cov_mean, vector cov_t, } } if (num_elements(log_cov_mean) > 0) { - cov = rep(log_cov_mean, t); + cov = rep_vector(log_cov_mean[1], t); } else { cov = log(cov_t); } @@ -47,10 +47,10 @@ vector update_covariate(vector log_cov_mean, vector cov_t, void covariate_lp(real[] log_cov_mean, real[] bp_effects, real[] bp_sd, int bp_n, - real cov_mean_logmean, real cov_mean_logsd) { + real[] cov_mean_logmean, real[] cov_mean_logsd) { // initial prior if (num_elements(log_cov_mean) > 0) { - log_cov_mean ~ normal(cov_mean_logmean, cov_mean_logsd); + log_cov_mean ~ normal(cov_mean_logmean[1], cov_mean_logsd[1]); } //breakpoint effects on Rt if (bp_n > 0) { diff --git a/inst/stan/functions/generated_quantities.stan b/inst/stan/functions/generated_quantities.stan index 7ef21e15f..e8e72ba08 100644 --- a/inst/stan/functions/generated_quantities.stan +++ b/inst/stan/functions/generated_quantities.stan @@ -15,28 +15,12 @@ vector calculate_Rt(vector infections, int seeding_time, } return(R); } -// Convert an estimate of Rt to growth -real[] R_to_growth(vector R, real gt_mean, real gt_sd) { - int t = num_elements(R); - real r[t]; - if (gt_sd > 0) { - real k = pow(gt_sd / gt_mean, 2); - for (s in 1:t) { - r[s] = (pow(R[s], k) - 1) / (k * gt_mean); - } - } else { - // limit as gt_sd -> 0 - for (s in 1:t) { - r[s] = log(R[s]) / gt_mean; - } - } - return(r); -} + // Calculate growth rate -real[] calculate_growth(vector infections, int seeding_time) { +vector calculate_growth(vector infections, int seeding_time) { int t = num_elements(infections); int ot = t - seeding_time; vector[t] log_inf = log(infections); vector[ot] growth = log_inf[(seeding_time + 1):t] - log_inf[seeding_time:(t - 1)]; - return(to_array_1d(growth)); + return(growth); } diff --git a/inst/stan/functions/infections.stan b/inst/stan/functions/infections.stan index 8aac27a76..1ca0af51d 100644 --- a/inst/stan/functions/infections.stan +++ b/inst/stan/functions/infections.stan @@ -29,7 +29,7 @@ vector generate_seed(real[] initial_infections, real[] initial_growth, int uot) return(seed_infs); } // generate infections using infectiousness -vector renewal_model(vector oR, vector uobs_infs, vector gt_rev_pmf, +vector renewal_model(vector oR, vector uobs_inf, vector gt_rev_pmf, int pop, int ht) { // time indices and storage int ot = num_elements(oR); diff --git a/inst/stan/simulate_infections.stan b/inst/stan/simulate_infections.stan index 377c73218..b9c48bbc7 100644 --- a/inst/stan/simulate_infections.stan +++ b/inst/stan/simulate_infections.stan @@ -27,10 +27,10 @@ transformed data { generated quantities { // generated quantities - matrix[n, t] infections; //latent infections - matrix[n, t - seeding_time] reports; // observed cases + vector[n] infections[t]; //latent infections + vector[n] reports[t - seeding_time]; // observed cases int imputed_reports[n, t - seeding_time]; - real r[n, t - seeding_time]; + vector[n] r[t - seeding_time]; vector[seeding_time] uobs_inf; for (i in 1:n) { // generate infections from Rt trace @@ -47,24 +47,21 @@ generated quantities { uobs_inf = generate_seed(initial_infections[i], initial_growth[i], seeding_time); // generate infections from Rt trace - infections[i] = to_row_vector(renewal_model(to_vector(R[i]), uobs_inf, - gt_rev_pmf, pop, future_time)); + infections[i] = renewal_model(R[i], uobs_inf, gt_rev_pmf, pop, future_time); // convolve from latent infections to mean of observations - reports[i] = to_row_vector(convolve_to_report( - to_vector(infections[i]), delay_rev_pmf, seeding_time) - ); + reports[i] = convolve_to_report(infections[i], delay_rev_pmf, seeding_time); // weekly reporting effect if (week_effect > 1) { - reports[i] = to_row_vector( - day_of_week_effect(to_vector(reports[i]), day_of_week, - to_vector(day_of_week_simplex[i]))); + reports[i] = day_of_week_effect( + reports[i], day_of_week, to_vector(day_of_week_simplex[i]) + ); } // scale observations if (obs_scale) { - reports[i] = to_row_vector(scale_obs(to_vector(reports[i]), frac_obs[i, 1])); + reports[i] = scale_obs(reports[i], frac_obs[i, 1]); } // simulate reported cases - imputed_reports[i] = report_rng(to_vector(reports[i]), rep_phi[i], obs_dist); - r[i] = calculate_growth(to_vector(infections[i]), seeding_time); + imputed_reports[i] = report_rng(reports[i], rep_phi[i], obs_dist); + r[i] = calculate_growth(infections[i], seeding_time); } }