Skip to content

Commit

Permalink
check fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk authored and seabbs committed Jan 27, 2023
1 parent 9167cb4 commit 8fe772a
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 95 deletions.
5 changes: 3 additions & 2 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ generation_time_opts <- function(..., disease, source, max = 15L,
names(gt) <- paste0("gt_", names(gt))

return(gt)
}

#' Delay Distribution Options
#'
Expand Down Expand Up @@ -338,7 +339,7 @@ process_opts <- function(model = "R",
rw = 0,
use_breakpoints = TRUE,
future = "latest",
stationary = FALSE
stationary = FALSE,
pop = 0) {

## check
Expand All @@ -357,7 +358,7 @@ process_opts <- function(model = "R",
use_breakpoints = use_breakpoints,
future = future,
stationary = stationary,
pop = pop,
pop = pop
)

# replace default settings with those specified by user
Expand Down
3 changes: 2 additions & 1 deletion inst/stan/data/covariates.stan
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
int process_model; // 0 = infections; 1 = growth; 2 = rt
int bp_n; // no of breakpoints (0 = no breakpoints)
int breakpoints[t - seeding_time]; // when do breakpoints occur
int cov_mean_const; // 0 = not const mean; 1 = const mean
real<lower = 0> cov_mean_mean[cov_mean_const]; // const covariate mean
real<lower = 0> cov_mean_sd[cov_mean_const]; // const covariate sd
real<lower = 0> cov_t[cov_mean_const ? 0 : t] // time-varying covariate mean
vector<lower = 0>[cov_mean_const ? 0 : t] cov_t; // time-varying covariate mean
2 changes: 1 addition & 1 deletion inst/stan/data/simulation_rt.stan
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
real<lower = 0> gt_sd[n, 1]; // sd of generation time
int<lower = 1> gt_max[1]; // maximum generation time
int gt_dist[1]; // 0 = lognormal; 1 = gamma
matrix[n, t - seeding_time] R; // reproduction number
vector[n] R[t - seeding_time]; // reproduction number
int pop; // susceptible population
74 changes: 22 additions & 52 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ transformed data{
int noise_terms = setup_noise(ot_h, t, horizon, estimate_r, stationary, future_fixed, fixed_from);
matrix[noise_terms, M] PHI = setup_gp(M, L, noise_terms); // basis function
// covariate mean
real cov_mean_logmean[cov_mean_const] = log(cov_mean^2 / sqrt(cov_sd^2 + cov_mean^2));
real cov_mean_logsd[cov_mean_const] = sqrt(log(1 + (cov_sd^2 / cov_mean^2)));
real cov_mean_logmean[cov_mean_const];
real cov_mean_logsd[cov_mean_const];

int delay_max_fixed = (n_fixed_delays == 0 ? 0 :
sum(delay_max[fixed_delays]) - num_elements(fixed_delays) + 1);
Expand All @@ -39,6 +39,11 @@ transformed data{
vector[truncation && trunc_fixed[1] ? trunc_max[1] : 0] trunc_fixed_pmf;
vector[delay_max_fixed] delays_fixed_pmf;

if (cov_mean_const) {
cov_mean_logmean[1] = log(cov_mean_mean[1]^2 / sqrt(cov_mean_sd[1]^2 + cov_mean_mean[1]^2));
cov_mean_logsd[1] = sqrt(log(1 + (cov_mean_sd[1]^2 / cov_mean_mean[1]^2)));
}

if (gt_fixed[1]) {
gt_fixed_pmf = discretised_pmf(gt_mean_mean[1], gt_sd_mean[1], gt_max[1], gt_dist[1], 1);
}
Expand Down Expand Up @@ -90,8 +95,10 @@ transformed parameters {
noise = update_gp(PHI, M, L, alpha[1], rho[1], eta, gp_type);
}
// update covariates
cov = update_covariate(log_cov_mean, cov_t, noise, breakpoints, bp_effects,
stationary, ot_h, 0);
cov = update_covariate(
log_cov_mean, cov_t, noise, breakpoints, bp_effects,
stationary, ot_h
);
uobs_inf = generate_seed(initial_infections, initial_growth, seeding_time);
// Estimate latent infections
if (process_model == 0) {
Expand Down Expand Up @@ -157,20 +164,18 @@ model {
trunc_dist, 1
);
if (estimate_r) {
// priors on Rt
rt_lp(
log_R, initial_infections, initial_growth, bp_effects, bp_sd, bp_n,
seeding_time, r_logmean, r_logsd, prior_infections, prior_growth
);
// penalised_prior on generation interval
delays_lp(
gt_mean, gt_mean_mean, gt_mean_sd, gt_sd, gt_sd_mean, gt_sd_sd, gt_dist, gt_weight
);
}
// priors on Rt
covariate_lp(log_cov_mean, bp_effects, bp_sd, bp_n, cov_mean_logmean, cov_mean_logsd);
infections_lp(initial_infections, initial_growth, prior_infections, prior_growth,
seeding_time);
// priors on covariates and infections
covariate_lp(
log_cov_mean, bp_effects, bp_sd, bp_n, cov_mean_logmean, cov_mean_logsd
);
infections_lp(
initial_infections, initial_growth, prior_infections, prior_growth, seeding_time
);
// prior observation scaling
if (obs_scale) {
frac_obs[1] ~ normal(obs_scale_mean, obs_scale_sd) T[0, 1];
Expand All @@ -185,33 +190,11 @@ model {

generated quantities {
int imputed_reports[ot_h];
<<<<<<< HEAD
<<<<<<< HEAD
vector[estimate_r > 0 ? 0: ot_h] gen_R;
real r[ot_h] - 1;
vector[return_likelihood ? ot : 0] log_lik;
if (estimate_r == 0){
=======
vector[estimate_r > 0 ? 0: ot_h] R;
real r[ot_h];
vector[return_likelihood > 1 ? ot : 0] log_lik;
if (estimate_r){
// estimate growth from estimated Rt
real set_gt_mean = (gt_mean_sd[1] > 0 ? gt_mean[1] : gt_mean_mean[1]);
real set_gt_sd = (gt_sd_sd [1]> 0 ? gt_sd[1] : gt_sd_mean[1]);
vector[gt_max[1]] gt_pmf = combine_pmfs(gt_fixed_pmf, gt_mean, gt_sd, gt_max, gt_dist, gt_max[1], 1);
} else {
>>>>>>> 7bc2510b (implement different model types)
=======
vector[estimate_r > 0 ? ot_h : 0] R;
real r[ot_h];
vector[return_likelihood > 1 ? ot : 0] log_lik;
<<<<<<< HEAD
if (estimate_r) {
>>>>>>> fe3d94be (generate R if estimate_r > 0)
=======
real r[ot_h - 1];
vector[return_likelihood ? ot : 0] log_lik;
if (estimate_r == 0 && process_model != 2) {
>>>>>>> b520805f (update stan code to reflect model update)

// sample generation time
real gt_mean_sample[1];
real gt_sd_sample[1];
Expand All @@ -225,27 +208,14 @@ generated quantities {
);

// calculate Rt using infections and generation time
<<<<<<< HEAD
<<<<<<< HEAD
gen_R = calculate_Rt(
infections, seeding_time, gt_rev_pmf, rt_half_window
=======
// estimate growth from calculated Rt
R = calculate_Rt(
infections, seeding_time, gt_mean_sample, gt_sd_sample,
max_gt, rt_half_window
>>>>>>> 7bc2510b (implement different model types)
);
=======
R = calculate_Rt(infections, seeding_time, gt_pmf);
R = calculate_Rt(infections, seeding_time, gt_rev_pmf);
} else {
R = cov;
}
if (process_model != 1) {
r = calculate_growth(infections, seeding_time);
} else {
r = cov;
>>>>>>> b520805f (update stan code to reflect model update)
}
// simulate reported cases
imputed_reports = report_rng(reports, rep_phi, obs_dist);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// update combined covariates
vector update_covariate(vector log_cov_mean, vector cov_t,
vector update_covariate(real[] log_cov_mean, vector cov_t,
vector noise, int[] bps,
real[] bp_effects, int stationary,
int t) {
Expand All @@ -10,7 +10,7 @@ vector update_covariate(vector log_cov_mean, vector cov_t,
// define result vectors
vector[t] bp = rep_vector(0, t);
vector[t] gp = rep_vector(0, t);
vector[t] R;
vector[t] cov;
// initialise breakpoints
if (bp_n) {
for (s in 1:t) {
Expand All @@ -35,7 +35,7 @@ vector update_covariate(vector log_cov_mean, vector cov_t,
}
}
if (num_elements(log_cov_mean) > 0) {
cov = rep(log_cov_mean, t);
cov = rep_vector(log_cov_mean[1], t);
} else {
cov = log(cov_t);
}
Expand All @@ -47,10 +47,10 @@ vector update_covariate(vector log_cov_mean, vector cov_t,

void covariate_lp(real[] log_cov_mean,
real[] bp_effects, real[] bp_sd, int bp_n,
real cov_mean_logmean, real cov_mean_logsd) {
real[] cov_mean_logmean, real[] cov_mean_logsd) {
// initial prior
if (num_elements(log_cov_mean) > 0) {
log_cov_mean ~ normal(cov_mean_logmean, cov_mean_logsd);
log_cov_mean ~ normal(cov_mean_logmean[1], cov_mean_logsd[1]);
}
//breakpoint effects on Rt
if (bp_n > 0) {
Expand Down
22 changes: 3 additions & 19 deletions inst/stan/functions/generated_quantities.stan
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,12 @@ vector calculate_Rt(vector infections, int seeding_time,
}
return(R);
}
// Convert an estimate of Rt to growth
real[] R_to_growth(vector R, real gt_mean, real gt_sd) {
int t = num_elements(R);
real r[t];
if (gt_sd > 0) {
real k = pow(gt_sd / gt_mean, 2);
for (s in 1:t) {
r[s] = (pow(R[s], k) - 1) / (k * gt_mean);
}
} else {
// limit as gt_sd -> 0
for (s in 1:t) {
r[s] = log(R[s]) / gt_mean;
}
}
return(r);
}

// Calculate growth rate
real[] calculate_growth(vector infections, int seeding_time) {
vector calculate_growth(vector infections, int seeding_time) {
int t = num_elements(infections);
int ot = t - seeding_time;
vector[t] log_inf = log(infections);
vector[ot] growth = log_inf[(seeding_time + 1):t] - log_inf[seeding_time:(t - 1)];
return(to_array_1d(growth));
return(growth);
}
2 changes: 1 addition & 1 deletion inst/stan/functions/infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ vector generate_seed(real[] initial_infections, real[] initial_growth, int uot)
return(seed_infs);
}
// generate infections using infectiousness
vector renewal_model(vector oR, vector uobs_infs, vector gt_rev_pmf,
vector renewal_model(vector oR, vector uobs_inf, vector gt_rev_pmf,
int pop, int ht) {
// time indices and storage
int ot = num_elements(oR);
Expand Down
25 changes: 11 additions & 14 deletions inst/stan/simulate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ transformed data {

generated quantities {
// generated quantities
matrix[n, t] infections; //latent infections
matrix[n, t - seeding_time] reports; // observed cases
vector[n] infections[t]; //latent infections
vector[n] reports[t - seeding_time]; // observed cases
int imputed_reports[n, t - seeding_time];
real r[n, t - seeding_time];
vector[n] r[t - seeding_time];
vector[seeding_time] uobs_inf;
for (i in 1:n) {
// generate infections from Rt trace
Expand All @@ -47,24 +47,21 @@ generated quantities {

uobs_inf = generate_seed(initial_infections[i], initial_growth[i], seeding_time);
// generate infections from Rt trace
infections[i] = to_row_vector(renewal_model(to_vector(R[i]), uobs_inf,
gt_rev_pmf, pop, future_time));
infections[i] = renewal_model(R[i], uobs_inf, gt_rev_pmf, pop, future_time);
// convolve from latent infections to mean of observations
reports[i] = to_row_vector(convolve_to_report(
to_vector(infections[i]), delay_rev_pmf, seeding_time)
);
reports[i] = convolve_to_report(infections[i], delay_rev_pmf, seeding_time);
// weekly reporting effect
if (week_effect > 1) {
reports[i] = to_row_vector(
day_of_week_effect(to_vector(reports[i]), day_of_week,
to_vector(day_of_week_simplex[i])));
reports[i] = day_of_week_effect(
reports[i], day_of_week, to_vector(day_of_week_simplex[i])
);
}
// scale observations
if (obs_scale) {
reports[i] = to_row_vector(scale_obs(to_vector(reports[i]), frac_obs[i, 1]));
reports[i] = scale_obs(reports[i], frac_obs[i, 1]);
}
// simulate reported cases
imputed_reports[i] = report_rng(to_vector(reports[i]), rep_phi[i], obs_dist);
r[i] = calculate_growth(to_vector(infections[i]), seeding_time);
imputed_reports[i] = report_rng(reports[i], rep_phi[i], obs_dist);
r[i] = calculate_growth(infections[i], seeding_time);
}
}

0 comments on commit 8fe772a

Please sign in to comment.