From caf16e9421534142df3397fbf8dd792e1df7e47c Mon Sep 17 00:00:00 2001 From: Adrien Corenflos Date: Thu, 5 Sep 2024 10:31:45 +0100 Subject: [PATCH] Plugging O(N) smoothing algos in PGibbs. This is not very useful for cheap models (as the cost of sampling a single trajectory is still O(N), not O(1)) but useful for models with expensive transition dynamics. --- book/pmcmc/pgibbs_ecological.py | 36 ++++++++++++++++++----------- book/smoothing/offline_smoothing.py | 2 +- particles/mcmc.py | 13 ++++++++++- 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/book/pmcmc/pgibbs_ecological.py b/book/pmcmc/pgibbs_ecological.py index 2c1f2e2..704be9a 100644 --- a/book/pmcmc/pgibbs_ecological.py +++ b/book/pmcmc/pgibbs_ecological.py @@ -5,7 +5,6 @@ model (2nd numerical example in Chapter 16 on PMCMC, Figures 16.8 to 16.10). """ - from collections import OrderedDict import numpy as np @@ -22,12 +21,17 @@ # state-space model class ThetaLogisticReparametrised(ssms.ThetaLogistic): default_params = {'precX': 4., 'precY': 6.25, 'tau0': 0.15, - 'tau1': 0.12, 'tau2': 0.1} + 'tau1': 0.12, 'tau2': 0.1} + def __init__(self, **kwargs): ssms.ThetaLogistic.__init__(self, **kwargs) self.sigmaX = 1. / np.sqrt(self.precX) self.sigmaY = 1. / np.sqrt(self.precY) + def upper_bound_log_pt(self, t): + return -0.5 * np.log(2 * np.pi * self.sigmaX ** 2) + + ssm_cls = ThetaLogisticReparametrised # data @@ -46,6 +50,7 @@ def __init__(self, **kwargs): 'precX': r'$1/\sigma_X^2$', 'precY': r'$1/\sigma_Y^2$', 'x_0': r'$x_0$'} + # Particle Gibbs class PGibbs(mcmc.ParticleGibbs): def update_theta(self, theta, x): @@ -70,8 +75,8 @@ def update_theta(self, theta, x): log_prob = -np.inf else: new_deltaX = dax - tau0 + tau1 * np.exp(tau2_prop * ax[:-1]) - log_prob = 0.5 * new_theta['precX']* (np.sum(deltaX**2) - -np.sum(new_deltaX**2)) + log_prob = 0.5 * new_theta['precX'] * (np.sum(deltaX ** 2) + - np.sum(new_deltaX ** 2)) log_prob += (prior.laws['tau2'].logpdf(tau2_prop) - prior.laws['tau2'].logpdf(theta['tau2'])) if np.log(stats.uniform.rvs()) < log_prob: @@ -95,12 +100,12 @@ def update_theta(self, theta, x): xtx = np.dot(features.T, features) beta_ols = linalg.solve(xtx, np.matmul(features.T, dax)) muprior = np.array([prior.laws[p].mu for p in ['tau0', 'tau1']]) - Qprior = np.diag([prior.laws[p].sigma**(-2) for p in ['tau0', 'tau1']]) + Qprior = np.diag([prior.laws[p].sigma ** (-2) for p in ['tau0', 'tau1']]) Qpost = Qprior + new_theta['precX'] * xtx Sigpost = linalg.inv(Qpost) mpost = (np.matmul(Qprior, muprior) + np.matmul(Sigpost, new_theta['precX'] - * np.matmul(xtx, beta_ols))) + * np.matmul(xtx, beta_ols))) while True: # reject until tau0 and tau1 are > 0 v = stats.multivariate_normal.rvs(mean=mpost, cov=Sigpost) @@ -111,10 +116,13 @@ def update_theta(self, theta, x): return new_theta + algos = OrderedDict() niter = 10 ** 5 burnin = int(niter / 10) -for name, opt in zip(['pg-back', 'pg'], [True, False]): +for name, opt in zip( + ['pg-back', 'pg', 'pg-reject', 'pg-mcmc'], + [True, False, "reject", "mcmc"]): algos[name] = PGibbs(ssm_cls=ssm_cls, data=data, prior=prior, Nx=50, niter=niter, backward_step=opt, store_x=True, verbose=10) @@ -124,6 +132,7 @@ def update_theta(self, theta, x): alg.run() print('CPU time: %.2f min' % (alg.cpu_time / 60)) + # Update rates def update_rate(x): """Update rate. @@ -139,12 +148,13 @@ def update_rate(x): """ return np.mean(x[1:] != x[:-1], axis=0) + # PLOTS # ===== savefigs = True # False if you don't want to save plots as pdfs plt.style.use('ggplot') -colors = {'pg-back': 'black', 'pg': 'gray'} -linestyles = {'pg-back': '-', 'pg': '--'} +colors = {'pg-back': 'black', 'pg': 'gray', 'pg-reject': 'blue', 'pg-mcmc': 'red'} +linestyles = {'pg-back': '-', 'pg': '--', 'pg-reject': '-.', 'pg-mcmc': ':'} # Update rates of PG samplers plt.figure() @@ -156,7 +166,7 @@ def update_rate(x): plt.ylabel('update rate') plt.legend(loc=6) # center left if savefigs: - plt.savefig('ecological_update_rates.pdf') # Figure 16.8 + plt.savefig('ecological_update_rates.pdf') #  Figure 16.8 # pair plots from PG-back plt.figure() @@ -173,7 +183,7 @@ def update_rate(x): plt.ylabel(pretty_par_names[p2]) i += 1 if savefigs: - plt.savefig('ecological_pairplot_taus.pdf') # Figure 16.10 + plt.savefig('ecological_pairplot_taus.pdf') #  Figure 16.10 # MCMC traces plt.figure() @@ -204,10 +214,10 @@ def update_rate(x): for i, p in enumerate(list(dict_prior.keys()) + ['x_0']): plt.subplot(2, 3, i + 1) for alg_name, alg in algos.items(): - th = alg.chain.x[:, 0] if p=='x_0' else alg.chain.theta[p] + th = alg.chain.x[:, 0] if p == 'x_0' else alg.chain.theta[p] acf_th = acf(th[burnin:], nlags=nlags, fft=True) plt.plot(acf_th, label=alg_name, color=colors[alg_name], - linestyle=linestyles[alg_name]) + linestyle=linestyles[alg_name]) plt.axis([0, nlags, -0.03, 1.]) plt.xlabel('lag') plt.ylabel(pretty_par_names[p]) diff --git a/book/smoothing/offline_smoothing.py b/book/smoothing/offline_smoothing.py index 2b3ce70..3666470 100644 --- a/book/smoothing/offline_smoothing.py +++ b/book/smoothing/offline_smoothing.py @@ -49,7 +49,7 @@ class DiscreteCox_with_add_f(ssms.DiscreteCox): """ def upper_bound_log_pt(self, t): - return -0.5 * np.log(2 * np.pi * self.sigma ** 2) + return -0.5 * np.log(2 * np.pi * self.sigmaX ** 2) # Aim is to compute the smoothing expectation of diff --git a/particles/mcmc.py b/particles/mcmc.py index d1030bd..ab0a481 100644 --- a/particles/mcmc.py +++ b/particles/mcmc.py @@ -583,6 +583,7 @@ def __init__( regenerate_data=False, backward_step=False, store_x=False, + backward_step_kwargs=None, ): GenericGibbs.__init__( self, @@ -594,10 +595,13 @@ def __init__( theta0=theta0, store_x=store_x, ) + if backward_step_kwargs is None: + backward_step_kwargs = {} self.Nx = Nx self.fk_cls = ssms.Bootstrap if fk_cls is None else fk_cls self.regenerate_data = regenerate_data self.backward_step = backward_step + self._backward_step_kwargs = backward_step_kwargs if backward_step_kwargs is not None else {} def fk_mod(self, theta): ssm = self.ssm_cls(**ssp.rec_to_dict(theta)) @@ -610,10 +614,17 @@ def update_states(self, theta, x): else: cpf = CSMC(fk=fk, N=self.Nx, xstar=x) cpf.run() - if self.backward_step: + if isinstance(self.backward_step, str): + if hasattr(cpf.hist, self.backward_step): + method = getattr(cpf.hist, self.backward_step) + else: + method = getattr(cpf.hist, "backward_sampling_" + self.backward_step) + new_x = method(1, **self._backward_step_kwargs) + elif self.backward_step: # need to check if it is exactly the True object new_x = cpf.hist.backward_sampling_ON2(1) else: new_x = cpf.hist.extract_one_trajectory() + if self.regenerate_data: self.data = fk.ssm.simulate_given_x(new_x) return new_x