try:
    import tinygp
except ImportError:
    %pip install -q tinygp
    
try:
    import numpyro
except ImportError:
    %pip uninstall -y jax jaxlib
    %pip install -q numpyro jax jaxlib
    
try:
    import arviz
except ImportError:
    %pip install -q arviz

Tutorial: Non-Gaussian Likelihoods

In this tutorial, we demonstrate how tinygp can be used in combination with non-Gaussian observation models. The basic idea is that your tinygp model can be a middle layer in your probabilistic model; it doesn't have to be at the bottom. One consequence is that the marginalization over process realizations is no longer analytic, so you'll need to marginalize numerically using Markov chain Monte Carlo (MCMC) or variational inference (VI). Since tinygp doesn't include any built-in inference methods, you'll need to use a different package, but luckily there are many good tools available! In this case, we'll use numpyro.
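To make "marginalize numerically" concrete, here is a deliberately naive numpy sketch. It is our own illustration, not part of tinygp or numpyro, and the helper names `matern52`, `poisson_loglike`, and `mc_log_marginal` are hypothetical. With a Gaussian likelihood the marginal likelihood p(y) has a closed form; with a Poisson likelihood, one option is to average p(y | f) over draws of f from the GP prior, with the hyperparameters held fixed. This simple Monte Carlo estimator quickly becomes unreliable as the number of data points grows, which is why MCMC or VI is used in practice.

```python
import numpy as np
from math import lgamma


def matern52(x1, x2, sigma, rho):
    # Matern-5/2 covariance with amplitude sigma and length scale rho
    # (standard form; assumed here to match sigma**2 * kernels.Matern52(rho))
    r = np.abs(x1[:, None] - x2[None, :]) / rho
    s = np.sqrt(5.0) * r
    return sigma**2 * (1.0 + s + s**2 / 3.0) * np.exp(-s)


def poisson_loglike(y, log_rate):
    # log Poisson(y | exp(log_rate)) summed over bins; log_rate may carry a
    # leading batch dimension of GP draws
    log_fact = np.array([lgamma(k + 1.0) for k in y])
    return np.sum(y * log_rate - np.exp(log_rate) - log_fact, axis=-1)


def mc_log_marginal(x, y, mean, sigma, rho, num_draws=20000, seed=0):
    # Naive Monte Carlo: p(y) is approximated by the average of p(y | f_s)
    # over draws f_s from the GP prior
    rng = np.random.default_rng(seed)
    K = matern52(x, x, sigma, rho) + 1e-5 * np.eye(len(x))  # jitter, like diag=1e-5
    L = np.linalg.cholesky(K)
    f = mean + rng.standard_normal((num_draws, len(x))) @ L.T
    log_like = poisson_loglike(y, f)
    # log-mean-exp, shifted by the maximum for numerical stability
    m = log_like.max()
    return m + np.log(np.mean(np.exp(log_like - m)))


# Tiny hypothetical example (separate names so the tutorial's x and y are untouched)
x_demo = np.linspace(-3, 3, 5)
y_demo = np.array([3, 0, 1, 0, 4])
print(mc_log_marginal(x_demo, y_demo, mean=0.0, sigma=1.0, rho=1.0))
```

Because every Poisson probability is at most one, the estimate is always negative; rerunning with a different `seed` gives a feel for its Monte Carlo noise.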

As our test case, we'll look at count data with a Poisson observation model where the underlying log-rate is modeled by a Gaussian process (also known as a log-Gaussian Cox process). To begin, let's simulate some data:

import numpy as np
import matplotlib.pyplot as plt

random = np.random.default_rng(203618)
x = np.linspace(-3, 3, 20)
true_log_rate = 2 * np.cos(2 * x)
y = random.poisson(np.exp(true_log_rate))
plt.plot(x, y, ".k", label="data")
plt.plot(x, np.exp(true_log_rate), "C1", label="true rate")
plt.legend(loc=2)
plt.xlabel("x")
_ = plt.ylabel("counts")
[Figure: simulated counts (data) and true rate versus x]

Markov chain Monte Carlo (MCMC)

Next, we set up the model in numpyro and run MCMC, following the example in Sampling with numpyro. The main difference here is that we use MCMC to marginalize over the GP realizations numerically instead of analytically. In the code, this shows up in the fact that the log_rate site is sampled without an obs=... argument, so numpyro treats it as a latent parameter to be sampled.
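For reference, the generative model implemented by `model` in the next code cell can be written as follows. This is our transcription of the code, where $k_{5/2}$ is the Matérn-5/2 correlation function and $\mathbf{f}$ is the vector of `log_rate` values at the inputs $x_i$. MCMC draws from the joint posterior $p(\mu, \sigma, \rho, \mathbf{f} \mid \mathbf{y})$; keeping only the $(\mu, \sigma, \rho)$ coordinates of those samples is what marginalizes over $\mathbf{f}$ numerically.

```latex
\begin{aligned}
\mu &\sim \mathcal{N}(0,\ 2^2), \qquad \sigma \sim \mathrm{HalfNormal}(3), \qquad \rho \sim \mathrm{HalfNormal}(10), \\
\mathbf{f} &\sim \mathcal{N}\!\left(\mu\,\mathbf{1},\ \sigma^2 K_\rho + 10^{-5} I\right), \qquad (K_\rho)_{ij} = k_{5/2}\!\left(|x_i - x_j| / \rho\right), \\
y_i \mid f_i &\sim \mathrm{Poisson}\!\left(e^{f_i}\right), \qquad i = 1, \dots, N.
\end{aligned}
```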

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from tinygp import kernels, GaussianProcess

# Enable float64 support for better numerical precision
# (jax.config.update replaces the removed `from jax.config import config`)
jax.config.update("jax_enable_x64", True)


def model(x, y=None):
    # The parameters of the GP model
    mean = numpyro.sample("mean", dist.Normal(0.0, 2.0))
    sigma = numpyro.sample("sigma", dist.HalfNormal(3.0))
    rho = numpyro.sample("rho", dist.HalfNormal(10.0))

    # Set up the kernel and GP objects
    kernel = sigma ** 2 * kernels.Matern52(rho)
    gp = GaussianProcess(kernel, x, diag=1e-5, mean=mean)

    # This parameter has shape (num_data,) and it encodes our beliefs about
    # the process rate in each bin
    log_rate = numpyro.sample("log_rate", gp.numpyro_dist())

    # Finally, our observation model is Poisson
    numpyro.sample("obs", dist.Poisson(jnp.exp(log_rate)), obs=y)


# Run the MCMC
nuts_kernel = numpyro.infer.NUTS(model, target_accept_prob=0.9)
mcmc = numpyro.infer.MCMC(
    nuts_kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
    progress_bar=False,
)
rng_key = jax.random.PRNGKey(55873)
mcmc.run(rng_key, x, y=y)
samples = mcmc.get_samples()

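We can summarize the MCMC results by plotting our inferred model and comparing it to the known ground truth. The solid curve is the posterior median rate, and the shaded bands are the central 50% and 90% credible intervals (the 25th to 75th and 5th to 95th percentiles of the posterior samples):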

q = np.percentile(samples["log_rate"], [5, 25, 50, 75, 95], axis=0)
plt.plot(x, np.exp(q[2]), color="C0", label="MCMC inferred rate")
plt.fill_between(x, np.exp(q[0]), np.exp(q[-1]), alpha=0.3, lw=0, color="C0")
plt.fill_between(x, np.exp(q[1]), np.exp(q[-2]), alpha=0.3, lw=0, color="C0")
plt.plot(x, np.exp(true_log_rate), "--", color="C1", label="true rate")
plt.plot(x, y, ".k", label="data")
plt.legend(loc=2)
plt.xlabel("x")
_ = plt.ylabel("counts")
[Figure: MCMC inferred rate with credible bands, true rate, and data versus x]
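If you would rather draw bands that match ±1σ and ±2σ of a Gaussian (as in the VI plot later on), the percentiles to pass to `np.percentile` follow from the normal CDF. A small sketch using only the standard library; `gaussian_band_percentiles` is a hypothetical helper name:

```python
from math import erf, sqrt


def gaussian_band_percentiles(k):
    # Lower and upper percentiles of the central +/- k sigma interval of a normal
    upper = 0.5 * (1.0 + erf(k / sqrt(2.0)))
    return 100.0 * (1.0 - upper), 100.0 * upper


lo1, hi1 = gaussian_band_percentiles(1.0)
lo2, hi2 = gaussian_band_percentiles(2.0)
levels = [lo2, lo1, 50.0, hi1, hi2]
print([round(p, 2) for p in levels])
```

These levels (roughly 2.3, 15.9, 50, 84.1, and 97.7) could replace `[5, 25, 50, 75, 95]` in the `np.percentile` call; the index-based `fill_between` calls would then shade ±2σ- and ±1σ-equivalent bands.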

Stochastic variational inference (SVI)

The above results look good and didn't take long to run, but with more data the runtime could become prohibitive, since the number of latent parameters (one log_rate value per data point) grows with the size of the dataset. In cases like this, we can approximate the relevant posterior expectations using stochastic variational inference (SVI) instead of MCMC. Here's one way that you could set up such a model using tinygp and numpyro:

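def guide(x, y=None):
    # Point estimates for the GP hyperparameters. These must be sample sites
    # (here Delta distributions) to match the "mean", "sigma", and "rho"
    # sample sites in the model; param sites alone raise
    # "RuntimeError: Site mean must be sampled in trace."
    mean_map = numpyro.param("mean_map", jnp.zeros(()))
    numpyro.sample("mean", dist.Delta(mean_map))
    sigma_map = numpyro.param(
        "sigma_map", jnp.ones(()), constraint=dist.constraints.positive
    )
    numpyro.sample("sigma", dist.Delta(sigma_map))
    rho_map = numpyro.param(
        "rho_map", 2 * jnp.ones(()), constraint=dist.constraints.positive
    )
    numpyro.sample("rho", dist.Delta(rho_map))

    # Mean-field Normal approximation to the log-rate in each bin
    mu = numpyro.param(
        "log_rate_mu", jnp.zeros_like(x) if y is None else jnp.log(y + 1)
    )
    sigma = numpyro.param(
        "log_rate_sigma",
        jnp.ones_like(x),
        constraint=dist.constraints.positive,
    )
    # to_event(1) matches the event shape of the multivariate GP site
    numpyro.sample("log_rate", dist.Normal(mu, sigma).to_event(1))


optim = numpyro.optim.Adam(0.01)
svi = numpyro.infer.SVI(
    model, guide, optim, numpyro.infer.Trace_ELBO(num_particles=10)
)
results = svi.run(jax.random.PRNGKey(55873), 3000, x, y=y, progress_bar=False)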
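For intuition about what `Trace_ELBO` optimizes, here is a numpy sketch of a reparameterized Monte Carlo estimate of the evidence lower bound for a mean-field Normal guide over `log_rate`, averaged over a number of particles (10 in the SVI call). It is our own illustration with hypothetical helper names (`gp_logpdf`, `mc_elbo`), the hyperparameters are held fixed for simplicity, and it uses the guide's analytic entropy where numpyro's `Trace_ELBO` uses log p − log q on each sample; both have the same expectation. Note that the guide has two variational parameters per data point, which is the growth mentioned above.

```python
import numpy as np
from math import lgamma, log, pi


def gp_logpdf(f, mean, K):
    # log N(f | mean, K) for each row of f, using a Cholesky factor of K
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L, (f - mean).T)  # L^{-1} (f - mean), shape (N, S)
    n = K.shape[0]
    return (
        -0.5 * np.sum(alpha**2, axis=0)
        - np.sum(np.log(np.diag(L)))
        - 0.5 * n * log(2 * pi)
    )


def mc_elbo(y, K, mean, q_mu, q_sigma, num_particles=10, seed=0):
    # Reparameterized Monte Carlo ELBO for a mean-field Normal guide on log_rate
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((num_particles, len(y)))
    f = q_mu + q_sigma * eps  # draws from q(log_rate)
    log_fact = np.array([lgamma(k + 1.0) for k in y])
    log_lik = np.sum(y * f - np.exp(f) - log_fact, axis=1)
    log_prior = gp_logpdf(f, mean, K)
    # Analytic entropy of the diagonal Normal guide
    entropy = np.sum(0.5 * np.log(2 * pi * np.e * q_sigma**2))
    return np.mean(log_lik + log_prior) + entropy


# Tiny hypothetical example; a squared-exponential kernel is used only for brevity
x_demo = np.linspace(-3, 3, 5)
y_demo = np.array([3, 0, 1, 0, 4])
K_demo = np.exp(-0.5 * (x_demo[:, None] - x_demo[None, :]) ** 2) + 1e-5 * np.eye(5)
q_mu = np.log(y_demo + 1.0)  # same initialization as log_rate_mu in the guide
q_sigma = np.ones(5)
print(mc_elbo(y_demo, K_demo, 0.0, q_mu, q_sigma))
```

SVI follows the gradient of estimates like this one with respect to `q_mu` and `q_sigma` (and, in the numpyro model, the hyperparameters too); a guide centered far from the data scores a much lower ELBO.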

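As above, we can plot the inferred rate and compare it to the ground truth. Here the shaded bands are exp(mu ± sigma) and exp(mu ± 2 sigma), computed from the guide's per-bin log_rate_mu and log_rate_sigma: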

mu = results.params["log_rate_mu"]
sigma = results.params["log_rate_sigma"]
plt.plot(x, np.exp(mu), color="C0", label="VI inferred rate")
plt.fill_between(
    x,
    np.exp(mu - 2 * sigma),
    np.exp(mu + 2 * sigma),
    alpha=0.3,
    lw=0,
    color="C0",
)
plt.fill_between(
    x, np.exp(mu - sigma), np.exp(mu + sigma), alpha=0.3, lw=0, color="C0"
)
plt.plot(x, np.exp(true_log_rate), "--", color="C1", label="true rate")
plt.plot(x, y, ".k", label="data")
plt.legend(loc=2)
plt.xlabel("x")
_ = plt.ylabel("counts")
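In this guide each `log_rate` value is approximated by an independent normal distribution, so the corresponding rate `exp(log_rate)` is log-normal. Because `exp` is monotonic, `exp(mu)` is the median of that rate (not its mean, which is `exp(mu + sigma**2 / 2)`), and `exp(mu ± k * sigma)` are exact quantiles. A quick numpy check, using hypothetical values `mu_bin` and `sigma_bin` in place of one bin of `results.params`:

```python
import numpy as np

# Hypothetical stand-ins for log_rate_mu and log_rate_sigma in a single bin
mu_bin, sigma_bin = 1.0, 0.5

rng = np.random.default_rng(1)
rate_draws = np.exp(rng.normal(mu_bin, sigma_bin, size=200_000))

# Median of the rate is close to exp(mu)
print(np.median(rate_draws), np.exp(mu_bin))
# Mean of the rate is close to exp(mu + sigma^2 / 2), which is larger
print(np.mean(rate_draws), np.exp(mu_bin + 0.5 * sigma_bin**2))
# The +/- 1 sigma band maps to the ~15.87th and ~84.13th percentiles of the rate
print(
    np.percentile(rate_draws, [15.87, 84.13]),
    np.exp(mu_bin + np.array([-1.0, 1.0]) * sigma_bin),
)
```

So the solid VI curve is best read as a median rate, which makes it directly comparable to the median curve in the MCMC plot.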