Skip to content

Commit

Permalink
update stan code to reflect model update
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk authored and seabbs committed Jan 27, 2023
1 parent 0a3d9bb commit 619a1a4
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 46 deletions.
1 change: 0 additions & 1 deletion inst/stan/data/backcalc.stan
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
int backcalc_prior; // Prior type to use for backcalculation
int rt_half_window; // Half the moving average window used when calculating Rt
5 changes: 4 additions & 1 deletion inst/stan/data/covariates.stan
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
int process_model; // 0 = infections; 1 = growth; 2 = rt
int process_model; // 0 = infections; 1 = growth; 2 = rt
int bp_n; // no of breakpoints (0 = no breakpoints)
int breakpoints[t - seeding_time]; // when do breakpoints occur
real<lower = 0> cov_mean_mean[cov_mean_const]; // const covariate mean
real<lower = 0> cov_mean_sd[cov_mean_const]; // const covariate sd
real<lower = 0> cov_t[cov_mean_const ? 0 : t] // time-varying covariate mean
2 changes: 0 additions & 2 deletions inst/stan/data/rt.stan
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
int estimate_r; // should the reproduction no be estimated (1 = yes)
real prior_infections; // prior for initial infections
real prior_growth; // prior on initial growth rate
real <lower = 0> r_mean; // prior mean of reproduction number
real <lower = 0> r_sd; // prior standard deviation of reproduction number
int future_fixed; // is underlying future Rt assumed to be fixed
int fixed_from; // Reference date for when Rt estimation should be fixed
int pop; // Initial susceptible population
41 changes: 26 additions & 15 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ transformed data{
// gaussian process
int noise_terms = setup_noise(ot_h, t, horizon, estimate_r, stationary, future_fixed, fixed_from);
matrix[noise_terms, M] PHI = setup_gp(M, L, noise_terms); // basis function
// Rt
real r_logmean = log(r_mean^2 / sqrt(r_sd^2 + r_mean^2));
real r_logsd = sqrt(log(1 + (r_sd^2 / r_mean^2)));
// covariate mean
real cov_mean_logmean[cov_mean_const] = log(cov_mean^2 / sqrt(cov_sd^2 + cov_mean^2));
real cov_mean_logsd[cov_mean_const] = sqrt(log(1 + (cov_sd^2 / cov_mean^2)));

int delay_max_fixed = (n_fixed_delays == 0 ? 0 :
sum(delay_max[fixed_delays]) - num_elements(fixed_delays) + 1);
Expand Down Expand Up @@ -61,7 +61,7 @@ parameters{
real<lower = ls_min,upper=ls_max> rho[fixed ? 0 : 1]; // length scale of noise GP
real<lower = 0> alpha[fixed ? 0 : 1]; // scale of of noise GP
vector[fixed ? 0 : M] eta; // unconstrained noise
real base_cov; // covariate (R/r/inf)
real log_cov_mean[cov_mean_const]; // covariate (R/r/inf)
real initial_infections[estimate_r] ; // seed infections
real initial_growth[estimate_r && seeding_time > 1 ? 1 : 0]; // seed growth rate
real<upper = gt_max[1]> gt_mean[estimate_r && gt_mean_sd[1] > 0]; // mean of generation time (if uncertain)
Expand All @@ -82,15 +82,16 @@ transformed parameters {
vector[fixed ? 0 : noise_terms] noise; // noise generated by the gaussian process
vector[seeding_time] uobs_inf;
vector[t] infections; // latent infections
vector[ot_h] cov; // covaraites
vector[ot_h] cov; // covariates
vector[ot_h] reports; // estimated reported cases
vector[ot] obs_reports; // observed estimated reported cases
// GP in noise - spectral densities
if (!fixed) {
noise = update_gp(PHI, M, L, alpha[1], rho[1], eta, gp_type);
}
// update covariates
cov = update_covariate(base_cov, noise, breakpoints, bp_effects, stationary, ot_h, 0);
cov = update_covariate(log_cov_mean, cov_t, noise, breakpoints, bp_effects,
stationary, ot_h, 0);
uobs_inf = generate_seed(initial_infections, initial_growth, seeding_time);
// Estimate latent infections
if (process_model == 0) {
Expand All @@ -101,14 +102,10 @@ transformed parameters {
infections = growth_model(cov, uobs_inf, future_time);
} else if (process_model == 2) {
// via Rt
vector[gt_max[1]] gt_rev_pmf;
vector[gt_max[1]] gt_pmf;
gt_rev_pmf = combine_pmfs(gt_fixed_pmf, gt_mean, gt_sd, gt_max, gt_dist, gt_max[1], 1, 1);
infections = renewal_model(cov, uobs_inf, gt_rev_pmf, pop, future_time);
} else {
// via deconvolution
infections = deconvolve_infections(
shifted_cases, noise, fixed, backcalc_prior
);
infections = renewal_model(cov, uobs_inf, gt_rev_pmf,
pop, future_time);
}
// convolve from latent infections to mean of observations
{
Expand Down Expand Up @@ -171,7 +168,7 @@ model {
);
}
// priors on Rt
covariate_lp(base_cov, bp_effects, bp_sd, bp_n, r_logmean, r_logsd);
covariate_lp(log_cov_mean, bp_effects, bp_sd, bp_n, cov_mean_logmean, cov_mean_logsd);
infections_lp(initial_infections, initial_growth, prior_infections, prior_growth,
seeding_time);
// prior observation scaling
Expand Down Expand Up @@ -209,8 +206,12 @@ generated quantities {
vector[estimate_r > 0 ? ot_h : 0] R;
real r[ot_h];
vector[return_likelihood > 1 ? ot : 0] log_lik;
<<<<<<< HEAD
if (estimate_r) {
>>>>>>> fe3d94be (generate R if estimate_r > 0)
=======
if (estimate_r == 0 && process_model != 2) {
>>>>>>> b520805f (update stan code to reflect model update)
// sample generation time
real gt_mean_sample[1];
real gt_sd_sample[1];
Expand All @@ -224,6 +225,7 @@ generated quantities {
);

// calculate Rt using infections and generation time
<<<<<<< HEAD
<<<<<<< HEAD
gen_R = calculate_Rt(
infections, seeding_time, gt_rev_pmf, rt_half_window
Expand All @@ -234,8 +236,17 @@ generated quantities {
max_gt, rt_half_window
>>>>>>> 7bc2510b (implement different model types)
);
=======
R = calculate_Rt(infections, seeding_time, gt_pmf);
} else {
R = cov;
}
if (process_model != 1) {
r = calculate_growth(infections, seeding_time);
} else {
r = cov;
>>>>>>> b520805f (update stan code to reflect model update)
}
r = calculate_growth(infections, seeding_time);
// simulate reported cases
imputed_reports = report_rng(reports, rep_phi, obs_dist);
// log likelihood of model
Expand Down
17 changes: 2 additions & 15 deletions inst/stan/functions/generated_quantities.stan
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// calculate Rt directly from inferred infections
vector calculate_Rt(vector infections, int seeding_time,
vector gt_rev_pmf, int smooth) {
vector gt_rev_pmf) {
int t = num_elements(infections);
int ot = t - seeding_time;
vector[ot] R;
Expand All @@ -13,20 +13,7 @@ vector calculate_Rt(vector infections, int seeding_time,
);
R[s] = infections[s + seeding_time] / infectiousness[s];
}
if (smooth) {
for (s in 1:ot) {
real window = 0;
sR[s] = 0;
for (i in max(1, s - smooth):min(ot, s + smooth)) {
sR[s] += R[i];
window += 1;
}
sR[s] = sR[s] / window;
}
}else{
sR = R;
}
return(sR);
return(R);
}
// Convert an estimate of Rt to growth
real[] R_to_growth(vector R, real gt_mean, real gt_sd) {
Expand Down
5 changes: 3 additions & 2 deletions inst/stan/functions/infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,12 @@ vector growth_model(vector r, vector uobs_inf, int ht) {
int nht = ot - ht;
int t = ot + uot;
vector[t] infections = rep_vector(1e-5, t);
vector[ot] exp_r = exp(r);
vector[ot] obs_inf;
// Update observed infections
obs_inf[1] = uobs_inf[uot] * r[1];
obs_inf[1] = uobs_inf[uot] * exp_r[1];
for (i in 2:t) {
obs_inf[i] = obs_inf[i - 1] * r[i];
obs_inf[i] = obs_inf[i - 1] * exp_r[i];
}
infections[1:uot] = infections[1:uot] + uobs_inf;
infections[(uot + 1):t] = infections[(uot + 1):t] + obs_inf;
Expand Down
29 changes: 19 additions & 10 deletions inst/stan/functions/rt.stan
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// update a vector of Rts
vector update_Rt(int t, real log_R, vector noise, int[] bps,
real[] bp_effects, int stationary) {
// update combined covariates
vector update_covariate(vector log_cov_mean, vector cov_t,
vector noise, int[] bps,
real[] bp_effects, int stationary,
int t) {
// define control parameters
int bp_n = num_elements(bp_effects);
int bp_c = 0;
Expand Down Expand Up @@ -32,17 +34,24 @@ vector update_Rt(int t, real log_R, vector noise, int[] bps,
gp = cumulative_sum(gp);
}
}
// Calculate Rt
R = rep_vector(log_R, t) + bp + gp;
R = exp(R);
return(R);
if (num_elements(log_cov_mean) > 0) {
cov = rep(log_cov_mean, t);
} else {
cov = log(cov_t);
}
// Calculate combined covariates
cov = cov + bp + gp;
cov = exp(cov);
return(cov);
}

void covariate_lp(real base_cov,
void covariate_lp(real[] log_cov_mean,
real[] bp_effects, real[] bp_sd, int bp_n,
real r_logmean, real r_logsd) {
real cov_mean_logmean, real cov_mean_logsd) {
// initial prior
base_cov ~ normal(r_logmean, r_logsd);
if (num_elements(log_cov_mean) > 0) {
log_cov_mean ~ normal(cov_mean_logmean, cov_mean_logsd);
}
//breakpoint effects on Rt
if (bp_n > 0) {
bp_sd[1] ~ normal(0, 0.1) T[0,];
Expand Down

0 comments on commit 619a1a4

Please sign in to comment.