Skip to content

Commit

Permalink
fix nofeatures model (remove coef)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjankowiak committed Mar 27, 2022
1 parent 52ed01e commit f44ff48
Showing 1 changed file with 35 additions and 32 deletions.
67 changes: 35 additions & 32 deletions pyrocov/mutrans.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,8 @@ def model(dataset, model_type, *, forecast_steps=None):
# Configure reparametrization (which does not affect model density).
reparam = {}
if "reparam" in model_type:
reparam["coef"] = LocScaleReparam()
if "nofeatures" not in model_type:
reparam["coef"] = LocScaleReparam()
if "localrate" in model_type or "nofeatures" in model_type:
reparam["rate_loc"] = LocScaleReparam()
if "localinit" in model_type:
Expand All @@ -501,22 +502,22 @@ def model(dataset, model_type, *, forecast_steps=None):
with poutine.reparam(config=reparam):

# Sample global random variables.
coef_scale = pyro.sample("coef_scale", dist.LogNormal(-4, 2))
if "nofeatures" not in model_type:
coef_scale = pyro.sample("coef_scale", dist.LogNormal(-4, 2))
rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))
init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))
if "localrate" in model_type:
rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2))
if "nofeatures" in model_type:
if "localrate" or "nofeatures" in model_type:
rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2))
if "localinit" in model_type:
init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2))

# Assume relative growth rate depends strongly on mutations and weakly
# on clade and place. Assume initial infections depend strongly on
# clade and place.
coef = pyro.sample(
"coef", dist.Laplace(torch.zeros(F), coef_scale).to_event(1)
) # [F]
if "nofeatures" not in model_type:
coef = pyro.sample(
"coef", dist.Laplace(torch.zeros(F), coef_scale).to_event(1)
) # [F]
with clade_plate:
if "localrate" in model_type:
rate_loc = pyro.sample(
Expand Down Expand Up @@ -959,33 +960,35 @@ def log_stats(dataset: dict, result: dict) -> dict:
stats = {k: float(v) for k, v in result["median"].items() if v.numel() == 1}
stats["loss"] = float(np.median(result["losses"][-100:]))
mutations = dataset["mutations"]
mean = result["mean"]["coef"].cpu()
if not mean.shape:
return stats # Work around error in map estimation.
logger.info(
"Dense data has shape {} totaling {} sequences".format(
" x ".join(map(str, dataset["weekly_clades"].shape)),
int(dataset["weekly_clades"].sum()),

if "coef" in result["mean"]:
mean = result["mean"]["coef"].cpu()
if not mean.shape:
return stats # Work around error in map estimation.
logger.info(
"Dense data has shape {} totaling {} sequences".format(
" x ".join(map(str, dataset["weekly_clades"].shape)),
int(dataset["weekly_clades"].sum()),
)
)
)

# Statistical significance.
std = result["std"]["coef"].cpu()
sig = mean.abs() / std
logger.info(f"|μ|/σ [median,max] = [{sig.median():0.3g},{sig.max():0.3g}]")
stats["|μ|/σ median"] = sig.median()
stats["|μ|/σ max"] = sig.max()
# Statistical significance.
std = result["std"]["coef"].cpu()
sig = mean.abs() / std
logger.info(f"|μ|/σ [median,max] = [{sig.median():0.3g},{sig.max():0.3g}]")
stats["|μ|/σ median"] = sig.median()
stats["|μ|/σ max"] = sig.max()

# Effects of individual mutations.
for name in ["S:D614G", "S:N501Y", "S:E484K", "S:L452R"]:
if name not in mutations:
continue
i = mutations.index(name)
m = mean[i] * 0.01
s = std[i] * 0.01
logger.info(f"ΔlogR({name}) = {m:0.3g} ± {s:0.2f}")
stats[f"ΔlogR({name}) mean"] = m
stats[f"ΔlogR({name}) std"] = s
# Effects of individual mutations.
for name in ["S:D614G", "S:N501Y", "S:E484K", "S:L452R"]:
if name not in mutations:
continue
i = mutations.index(name)
m = mean[i] * 0.01
s = std[i] * 0.01
logger.info(f"ΔlogR({name}) = {m:0.3g} ± {s:0.2f}")
stats[f"ΔlogR({name}) mean"] = m
stats[f"ΔlogR({name}) std"] = s

# Growth rates of individual clades.
rate = quotient_central_moments(
Expand Down

0 comments on commit f44ff48

Please sign in to comment.