# Install tinygp only if it cannot be imported in the current environment.
# `!pip` is IPython shell syntax, so this cell runs inside a notebook only.
try:
    import tinygp
except ImportError:
    !pip install -q tinygp
    
# Install numpyro only if it cannot be imported. jax and jaxlib are removed
# first and then reinstalled in the same pip call as numpyro -- presumably so
# pip resolves jax/jaxlib versions compatible with numpyro (TODO confirm).
try:
    import numpyro
except ImportError:
    !pip uninstall -y jax jaxlib
    !pip install -q numpyro jax jaxlib
    
try:
    import arviz
except ImportError:
    !pip install arviz
/home/docs/checkouts/readthedocs.org/user_builds/tinygp/envs/latest/lib/python3.7/site-packages/jax/experimental/optimizers.py:30: FutureWarning: jax.experimental.optimizers is deprecated, import jax.example_libraries.optimizers instead
  FutureWarning)
/home/docs/checkouts/readthedocs.org/user_builds/tinygp/envs/latest/lib/python3.7/site-packages/jax/experimental/stax.py:30: FutureWarning: jax.experimental.stax is deprecated, import jax.example_libraries.stax instead
  FutureWarning)

Alternative likelihoods

import numpy as np
import matplotlib.pyplot as plt

# Simulate Poisson counts whose log-rate is a smooth function of x:
# log_rate(x) = 2 cos(2x) on 20 evenly spaced points in [-3, 3].
# The NumPy generator is called `rng` (not `random`) so it cannot be confused
# with the JAX `random` module that is imported into this namespace later.
rng = np.random.default_rng(203618)
x = np.linspace(-3, 3, 20)
true_log_rate = 2 * np.cos(2 * x)
y = rng.poisson(np.exp(true_log_rate))
# Plot the simulated counts against the true (exponentiated) rate.
ax = plt.gca()
ax.plot(x, y, ".k", label="data")
ax.plot(x, np.exp(true_log_rate), "C1", label="true rate")
ax.legend(loc=2)
ax.set_xlabel("x")
# Assign to `_` so the notebook does not echo the returned Text object.
_ = ax.set_ylabel("counts")
../_images/poisson_2_0.png
import jax

# Turn on 64-bit (double precision) arrays in JAX before any model code runs.
# `jax.config.update` replaces the deprecated `from jax.config import config`
# import path, which newer JAX releases no longer provide.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

from tinygp import kernels, GaussianProcess


def model(x, y=None):
    """Poisson count model with a Gaussian-process prior on the log-rate.

    A GP with a sampled constant mean and a scaled Matern-3/2 kernel gives a
    multivariate-normal prior over the latent log-rate at the inputs ``x``.
    The counts are then modelled as Poisson with rate ``exp(log_rate)``.

    Args:
        x: Input coordinates at which the latent log-rate is evaluated.
        y: Observed counts, passed as ``obs`` to the Poisson site. When
            ``None``, the ``"obs"`` site is left unobserved.
    """
    # Hyperparameters: GP mean, kernel amplitude, and kernel length scale.
    gp_mean = numpyro.sample("mean", dist.Normal(0.0, 2.0))
    amplitude = numpyro.sample("sigma", dist.HalfNormal(3.0))
    length_scale = numpyro.sample("rho", dist.HalfNormal(10.0))

    # diag=1e-5 adds a small constant to the covariance diagonal
    # (presumably jitter for a stable Cholesky factor -- confirm).
    gp = GaussianProcess(
        amplitude ** 2 * kernels.Matern32(length_scale),
        x,
        diag=1e-5,
        mean=gp_mean,
    )

    # Latent log-rate drawn from the GP prior, expressed as an MVN.
    gp_prior = dist.MultivariateNormal(loc=gp.loc, scale_tril=gp.scale_tril)
    log_rate = numpyro.sample("log_rate", gp_prior)

    numpyro.sample("obs", dist.Poisson(jnp.exp(log_rate)), obs=y)


from jax import local_device_count

# NUTS sampler for `model`.
nuts_kernel = NUTS(model, target_accept_prob=0.9)

# Run chains in parallel only when there is at least one device per chain.
# Otherwise request sequential chains explicitly: numpyro would fall back to
# sequential sampling anyway, but only after emitting a "not enough devices
# to run parallel chains" UserWarning.
num_chains = 2
chain_method = "parallel" if local_device_count() >= num_chains else "sequential"
mcmc = MCMC(
    nuts_kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=num_chains,
    chain_method=chain_method,
    progress_bar=False,
)
rng_key = random.PRNGKey(55873)
/home/docs/checkouts/readthedocs.org/user_builds/tinygp/envs/latest/lib/python3.7/site-packages/numpyro/infer/mcmc.py:280: UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
  self.num_chains, local_device_count(), self.num_chains
%%time
# Fit the model to the simulated counts and collect the posterior draws.
mcmc.run(rng_key, x, y=y)
samples = mcmc.get_samples()
# Wait until the `log_rate` draws are actually materialized (JAX may dispatch
# work asynchronously), so the %%time measurement covers all the work.
# Assigning to `_` keeps the notebook from printing the array.
_ = samples["log_rate"].block_until_ready()
CPU times: user 12.6 s, sys: 75.6 ms, total: 12.7 s
Wall time: 12.7 s
import arviz as az

data = az.from_numpyro(mcmc)

# Summarize the scalar hyperparameters only; the per-point latent `log_rate`
# values would otherwise dominate the table.
hyperparams = [
    name for name in data.posterior.data_vars if name != "log_rate"
]
az.summary(data, var_names=hyperparams)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
mean 0.343 1.192 -2.070 2.485 0.034 0.026 1267.0 1186.0 1.0
rho 1.405 0.718 0.313 2.819 0.040 0.029 363.0 280.0 1.0
sigma 2.697 1.209 0.910 5.037 0.038 0.027 942.0 1113.0 1.0
# Posterior percentiles (5, 25, 50, 75, 95) of the latent log-rate at each x,
# one row per percentile, then mapped to the rate scale.
q = np.percentile(samples["log_rate"], [5, 25, 50, 75, 95], axis=0)
rate_q = np.exp(q)

plt.plot(x, y, ".k", label="data")
plt.plot(x, np.exp(true_log_rate), color="C1", label="true rate")
plt.plot(x, rate_q[2], color="C0", label="inferred rate")
# Shaded bands: the 5-95% interval first, then the 25-75% interval on top.
for lo, hi in ((0, -1), (1, -2)):
    plt.fill_between(x, rate_q[lo], rate_q[hi], alpha=0.3, lw=0, color="C0")
plt.legend(loc=2)
plt.xlabel("x")
_ = plt.ylabel("counts")
../_images/poisson_6_0.png