
Support for Numpyro Discrete Enumeration #195

Open
LSZ2001 opened this issue Dec 11, 2023 · 1 comment
LSZ2001 commented Dec 11, 2023

Hi @dfm,

I am building a Gaussian process model that contains discrete variables. Here is a simplified version:

import matplotlib.pyplot as plt
import jax
from jax import random
import jax.numpy as jnp
import numpy as onp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from tinygp import GaussianProcess, kernels, transforms

def model(X_train=None, y_train=None, X_test=None):
    # One length scale per input dimension
    with numpyro.plate("dimensions", X_train.shape[1]):
        ls = numpyro.sample("ls", dist.Gamma(3, 0.5))
    # One Bernoulli indicator per kernel term; induces sparsity on the kernel structure
    with numpyro.plate("dimensions_comb", 2 ** X_train.shape[1] - 1):
        pis = numpyro.sample("pi", dist.Bernoulli(0.5), infer={"enumerate": "parallel"})
    # pis = jnp.array([0, 1, 1.0])  # fixed indicators, for debugging

    kernel1 = transforms.Subspace(0, kernels.ExpSquared(ls[0]))
    kernel2 = transforms.Subspace(1, kernels.ExpSquared(ls[1]))
    kernel = pis[0] * kernel1 + pis[1] * kernel2 + pis[2] * kernel1 * kernel2
    gp = GaussianProcess(kernel, X_train, diag=0.1)

    with jax.ensure_compile_time_eval():
        with numpyro.plate("data", X_train.shape[0]):
            numpyro.sample("gp", gp.numpyro_dist(), obs=y_train)
    if y_train is not None:
        with numpyro.plate("data", X_test.shape[0]):
            numpyro.deterministic("f", gp.condition(y_train, X_test).gp.loc)
    
# Data creation
onp.random.seed(0)
N = 100
X_train = (jnp.array(onp.random.uniform(size=(N, 2))) - 0.5) * 10
y_train = 0.5 * X_train[:, 0] + 1 * X_train[:, 1] + 2 * X_train[:, 0] * X_train[:, 1]
numpyro.render_model(model, model_args=(X_train, y_train, X_train), render_distributions=True, render_params=True)

# GP fitting
rng_key = random.PRNGKey(0)
num_chains = 1
hmc = MCMC(NUTS(model), num_samples=1000, num_warmup=1000, num_chains=num_chains)
hmc.run(rng_key, X_train, y_train, X_train)
hmc.print_summary(exclude_deterministic=True)
hmc_samples = hmc.get_samples()
plt.errorbar(y_train, jnp.mean(hmc_samples["f"], axis=0), jnp.std(hmc_samples["f"], axis=0), color="k", fmt=".")

Running this code gives the following error, since enumeration adds batch dimensions to the kernel hyperparameters:

ValueError: The value of a constant kernel must be a scalar

Is there a workaround that would allow TinyGP kernels to be compatible with Numpyro enumeration? Thank you very much!
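To make the shape problem concrete, here is a rough numpy sketch of what goes wrong (this is only a conceptual stand-in; numpyro's actual enumeration uses broadcasting with fresh batch dimensions rather than a dense grid):

```python
import numpy as np

# Conceptual sketch: parallel enumeration effectively evaluates the model for
# every support value of each discrete variable, so a sampled indicator picks
# up extra batch dimensions instead of being a concrete 0/1 value.
support = np.array([0.0, 1.0])  # Bernoulli support
# One axis per enumerated indicator; the last axis indexes the three indicators.
pis_enum = np.stack(np.meshgrid(support, support, support, indexing="ij"), axis=-1)
print(pis_enum.shape)  # (2, 2, 2, 3)

# Inside the model, pis[0] is then an array rather than a scalar, which is
# what trips the "value of a constant kernel must be a scalar" check.
pi0 = pis_enum[..., 0]
print(pi0.shape)  # (2, 2, 2)
```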

dfm (Owner) commented Jan 16, 2024

I'm not very familiar with this feature of numpyro, so could you say a few more words about what it is supposed to be doing and what behavior you want?

One note: I don't think this part of the model makes much sense, and might cause problems if I understand what is supposed to happen here:

        with numpyro.plate("data", X_train.shape[0]):
            numpyro.sample("gp", gp.numpyro_dist(), obs=y_train)

The problem here is that you can't have a plate over data points in a GP model, since the points are not independent. This means that mixture models are hard to implement using GPs, because you can't do closed form marginalization over mixture memberships in the straightforward way that you can with models where the data points are conditionally independent.
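A toy numpy illustration of that point (not tinygp code): with a correlated covariance, the joint GP log-likelihood is not the sum of independent per-point terms, which is exactly the factorization a "data" plate asserts.

```python
import numpy as np

# Toy sketch: a GP likelihood is one joint multivariate normal over all
# observations, so factorizing it point-by-point (what an independence plate
# assumes) changes the model whenever the kernel correlates the points.
cov = np.array([[1.0, 0.9], [0.9, 1.0]])  # kernel with strong correlation
y = np.array([0.5, -0.2])

# Joint multivariate normal log-density, computed by hand
n = len(y)
_, logdet = np.linalg.slogdet(cov)
joint = -0.5 * (n * np.log(2 * np.pi) + logdet + y @ np.linalg.solve(cov, y))

# Per-point factorized log-density (unit marginal variances)
factored = np.sum(-0.5 * (np.log(2 * np.pi) + y**2))

print(joint, factored)  # the two disagree; the off-diagonal terms matter
```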
