Tutorial: Non-Gaussian Likelihoods
Contents
# Install each tutorial dependency on demand: attempt the import first and only
# call pip (through the IPython `%pip` magic) when the import fails.
try:
    import tinygp
except ImportError:
    %pip install -q tinygp

try:
    import numpyro
except ImportError:
    # Any existing jax/jaxlib is removed before numpyro, jax and jaxlib are
    # installed together -- presumably so pip resolves a mutually compatible
    # set; confirm before changing the order of these two commands.
    %pip uninstall -y jax jaxlib
    %pip install -q numpyro jax jaxlib

try:
    import arviz
except ImportError:
    %pip install arviz
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Tutorial: Non-Gaussian Likelihoods¶
In this tutorial, we demonstrate how tinygp
can be used in combination with non-Gaussian observation models.
The basic idea is that your tinygp
model can be a middle layer in your probabilistic model; it doesn’t just have to be at the bottom.
One issue with this is that the marginalization over process realizations will no longer be analytic, so you’ll need to marginalize numerically using Markov chain Monte Carlo (MCMC) or variational inference (VI).
Since tinygp
doesn’t include any built-in inference methods you’ll need to use a different package, but luckily there are lots of good tools that exist!
In this case, we’ll use numpyro
.
As our test case, we’ll look at counts data with a Poisson observation model where the underlying log-rate is modeled by a Gaussian process (also known as a Cox process). To begin, let’s simulate some data:
import numpy as np
import matplotlib.pyplot as plt
# Simulate Poisson counts on a regular grid from a known, smoothly varying rate.
random = np.random.default_rng(seed=203618)
x = np.linspace(start=-3, stop=3, num=20)
true_log_rate = 2 * np.cos(2 * x)
true_rate = np.exp(true_log_rate)
y = random.poisson(lam=true_rate)

# Show the simulated counts on top of the rate that generated them.
plt.plot(x, y, ".k", label="data")
plt.plot(x, true_rate, "C1", label="true rate")
plt.legend(loc="upper left")
plt.xlabel("x")
_ = plt.ylabel("counts")
Markov chain Monte Carlo (MCMC)¶
Then we set up the model in numpyro
and run the MCMC, following the example in Sampling with numpyro.
The main difference here is that we’re using MCMC to marginalize over GP realizations (that is encoded in the following by the fact that the log_rate
parameter doesn’t have the obs=...
argument set), instead of analytically marginalizing.
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from tinygp import kernels, GaussianProcess
# We'll enable float64 support here for better numerical performance.
# The config object is reached through the top-level `jax` package: the
# `jax.config` submodule import path (`from jax.config import config`) has been
# removed from recent JAX releases, while `jax.config` remains available.
config = jax.config
config.update("jax_enable_x64", True)
def model(x, y=None):
    """Poisson counts whose log-rate is drawn from a Matern-5/2 GP prior.

    ``x`` holds the input coordinates handed to the GP and ``y`` the observed
    counts. ``y`` is passed as ``obs`` to the ``"obs"`` site, so the default
    ``y=None`` leaves that site unobserved.
    """
    # Priors on the GP mean, amplitude and kernel scale
    gp_mean = numpyro.sample("mean", dist.Normal(0.0, 2.0))
    amplitude = numpyro.sample("sigma", dist.HalfNormal(3.0))
    scale = numpyro.sample("rho", dist.HalfNormal(10.0))

    # The GP prior over the log-rate; the "log_rate" site has shape
    # (num_data,) and encodes our beliefs about the process rate in each bin
    gp = GaussianProcess(
        amplitude ** 2 * kernels.Matern52(scale),
        x,
        diag=1e-5,
        mean=gp_mean,
    )
    log_rate = numpyro.sample("log_rate", gp.numpyro_dist())

    # Poisson observation model on the exponentiated log-rate
    numpyro.sample("obs", dist.Poisson(jnp.exp(log_rate)), obs=y)
# Run the MCMC: NUTS with a target acceptance probability of 0.9, drawing two
# chains of 1000 warmup steps followed by 1000 retained samples each
nuts_kernel = numpyro.infer.NUTS(model, target_accept_prob=0.9)
mcmc_settings = dict(
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
    progress_bar=False,
)
mcmc = numpyro.infer.MCMC(nuts_kernel, **mcmc_settings)

rng_key = jax.random.PRNGKey(55873)
mcmc.run(rng_key, x, y=y)
samples = mcmc.get_samples()
/home/docs/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/ipykernel_launcher.py:38: UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
We can summarize the MCMC results by plotting our inferred model (here we’re showing the central 50% and 90% credible regions), and compare it to the known ground truth:
# Pointwise percentiles of the sampled log-rate, taken across the samples axis,
# then mapped to rate space for plotting
q = np.percentile(samples["log_rate"], [5, 25, 50, 75, 95], axis=0)
rate_q = np.exp(q)

# Median as a line; the 5-95 and 25-75 percentile ranges as shaded bands
band_style = dict(alpha=0.3, lw=0, color="C0")
plt.plot(x, rate_q[2], color="C0", label="MCMC inferred rate")
plt.fill_between(x, rate_q[0], rate_q[4], **band_style)
plt.fill_between(x, rate_q[1], rate_q[3], **band_style)

# Ground truth and the observed counts for comparison
plt.plot(x, np.exp(true_log_rate), "--", color="C1", label="true rate")
plt.plot(x, y, ".k", label="data")
plt.legend(loc="upper left")
plt.xlabel("x")
_ = plt.ylabel("counts")
Stochastic variational inference (SVI)¶
The above results look good and didn’t take long to run, but if we had more data the runtime could become prohibitive, since the number of parameters scales with the size of the dataset.
In cases like this, we can compute the relevant expectations using stochastic variational inference (SVI) instead of MCMC.
Here’s one way that you could set up such a model using tinygp
and numpyro
:
def guide(x, y=None):
    """Variational guide for ``model``.

    ``model`` declares "mean", "sigma", "rho" and "log_rate" as sample sites,
    and SVI replays the guide's trace into the model, so each of them must be
    a *sample* site here as well. Registering them only as ``numpyro.param``
    fails with "Site mean must be sampled in trace."

    The GP hyperparameters get point estimates: each is sampled from a
    ``Delta`` located at a learned parameter. The parameter sites carry a
    ``_loc`` suffix because site names must be unique within the guide. The
    log-rate gets an independent Normal per data point, with learned
    ``log_rate_mu`` and ``log_rate_sigma``.

    ``x`` and ``y`` mirror the arguments of ``model``; when ``y`` is given it
    is only used to initialize ``log_rate_mu`` at ``log(y + 1)``.
    """
    # Point estimates of the GP hyperparameters
    mean_loc = numpyro.param("mean_loc", jnp.zeros(()))
    sigma_loc = numpyro.param(
        "sigma_loc", jnp.ones(()), constraint=dist.constraints.positive
    )
    rho_loc = numpyro.param(
        "rho_loc", 2 * jnp.ones(()), constraint=dist.constraints.positive
    )
    numpyro.sample("mean", dist.Delta(mean_loc))
    numpyro.sample("sigma", dist.Delta(sigma_loc))
    numpyro.sample("rho", dist.Delta(rho_loc))

    # Mean-field Normal approximation to the posterior over the log-rate
    mu = numpyro.param(
        "log_rate_mu", jnp.zeros_like(x) if y is None else jnp.log(y + 1)
    )
    sigma = numpyro.param(
        "log_rate_sigma",
        jnp.ones_like(x),
        constraint=dist.constraints.positive,
    )
    # to_event(1) makes the data axis an event dimension, matching the
    # vector-valued "log_rate" site in the model
    numpyro.sample("log_rate", dist.Normal(mu, sigma).to_event(1))
# Fit the guide's parameters: Adam with step size 0.01 on an ELBO estimated
# from 10 particles, run for 3000 steps
optim = numpyro.optim.Adam(step_size=0.01)
svi = numpyro.infer.SVI(
    model,
    guide,
    optim,
    numpyro.infer.Trace_ELBO(num_particles=10),
)
results = svi.run(
    jax.random.PRNGKey(55873),
    3000,
    x,
    y=y,
    progress_bar=False,
)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_442/551631826.py in <module>
18 optim = numpyro.optim.Adam(0.01)
19 svi = numpyro.infer.SVI(model, guide, optim, numpyro.infer.Trace_ELBO(10))
---> 20 results = svi.run(jax.random.PRNGKey(55873), 3000, x, y=y, progress_bar=False)
~/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/numpyro/infer/svi.py in run(self, rng_key, num_steps, progress_bar, stable_update, init_state, *args, **kwargs)
340
341 if init_state is None:
--> 342 svi_state = self.init(rng_key, *args, **kwargs)
343 else:
344 svi_state = init_state
~/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/numpyro/infer/svi.py in init(self, rng_key, *args, **kwargs)
180 guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
181 model_trace = trace(replay(model_init, guide_trace)).get_trace(
--> 182 *args, **kwargs, **self.static_kwargs
183 )
184 params = {}
~/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
173
~/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
~/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
~/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
/tmp/ipykernel_442/1231121528.py in model(x, y)
13 def model(x, y=None):
14 # The parameters of the GP model
---> 15 mean = numpyro.sample("mean", dist.Normal(0.0, 2.0))
16 sigma = numpyro.sample("sigma", dist.HalfNormal(3.0))
17 rho = numpyro.sample("rho", dist.HalfNormal(10.0))
~/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/numpyro/primitives.py in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
217
218 # ...and use apply_stack to send it to the Messengers
--> 219 msg = apply_stack(initial_msg)
220 return msg["value"]
221
~/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/numpyro/primitives.py in apply_stack(msg)
45 pointer = 0
46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47 handler.process_message(msg)
48 # When a Messenger sets the "stop" field of a message,
49 # it prevents any Messengers above it on the stack from being applied.
~/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/numpyro/handlers.py in process_message(self, msg)
221 return None
222 if guide_msg["type"] != "sample" or guide_msg["is_observed"]:
--> 223 raise RuntimeError(f"Site {name} must be sampled in trace.")
224 msg["value"] = guide_msg["value"]
225 msg["infer"] = guide_msg["infer"]
RuntimeError: Site mean must be sampled in trace.
As above, we can plot our inferred conditional model and compare it to the ground truth:
# Fitted variational mean and standard deviation of the log-rate
mu = results.params["log_rate_mu"]
sigma = results.params["log_rate_sigma"]

# Central estimate as a line; the 2-sigma band is drawn first, then the
# 1-sigma band on top of it, both mapped to rate space
band_style = dict(alpha=0.3, lw=0, color="C0")
plt.plot(x, np.exp(mu), color="C0", label="VI inferred rate")
for n_sigma in (2, 1):
    half_width = n_sigma * sigma
    plt.fill_between(
        x, np.exp(mu - half_width), np.exp(mu + half_width), **band_style
    )

# Ground truth and the observed counts for comparison
plt.plot(x, np.exp(true_log_rate), "--", color="C1", label="true rate")
plt.plot(x, y, ".k", label="data")
plt.legend(loc="upper left")
plt.xlabel("x")
_ = plt.ylabel("counts")