# Bootstrap cell: install any missing dependency on the fly. The lines starting
# with `!` are IPython shell escapes, so this cell only works inside a
# notebook / IPython session, not as a plain Python script.
try:
    import tinygp
except ImportError:
    !pip install -q tinygp
    
# When numpyro is missing, the existing jax/jaxlib packages are removed before
# numpyro, jax and jaxlib are installed together (presumably to avoid a version
# mismatch between them - confirm).
try:
    import numpyro
except ImportError:
    !pip uninstall -y jax jaxlib
    !pip install -q numpyro jax jaxlib
    
# arviz is used for the posterior summary table.
try:
    import arviz
except ImportError:
    !pip install arviz
/home/docs/checkouts/readthedocs.org/user_builds/tinygp/envs/latest/lib/python3.7/site-packages/jax/experimental/optimizers.py:30: FutureWarning: jax.experimental.optimizers is deprecated, import jax.example_libraries.optimizers instead
  FutureWarning)
/home/docs/checkouts/readthedocs.org/user_builds/tinygp/envs/latest/lib/python3.7/site-packages/jax/experimental/stax.py:30: FutureWarning: jax.experimental.stax is deprecated, import jax.example_libraries.stax instead
  FutureWarning)

Posterior inference using numpyro

import numpy as np
import matplotlib.pyplot as plt

# Seed NumPy's global generator so every run reproduces the same dataset.
np.random.seed(42)


def _noise_free(x):
    """Return the linear trend plus chirped sinusoid evaluated at ``x``."""
    return 0.2 * (x - 5) + np.sin(3 * x + 0.1 * (x - 5) ** 2)


# 28 + 18 sample locations drawn from [0, 3.8) and [5.5, 10), leaving a gap
# between the two stretches, then sorted into increasing order.
_left = np.random.uniform(0, 3.8, 28)
_right = np.random.uniform(5.5, 10, 18)
t = np.sort(np.concatenate([_left, _right]))

# Per-point uncertainties, followed by the noisy observations built from them.
yerr = np.random.uniform(0.08, 0.22, t.size)
y = _noise_free(t) + yerr * np.random.randn(t.size)

# Dense grid carrying the noise-free curve, used for plotting and prediction.
true_t = np.linspace(0, 10, 100)
true_y = _noise_free(true_t)

# Draw the noise-free curve as a faint line underneath the simulated
# measurements and their error bars.
ax = plt.gca()
ax.plot(true_t, true_y, "k", lw=1.5, alpha=0.3)
ax.errorbar(t, y, yerr=yerr, fmt=".k", capsize=0)
ax.set_xlabel("x [day]")
ax.set_ylabel("y [ppm]")
ax.set_xlim(0, 10)
ax.set_ylim(-2.5, 2.5)
# Assign the return value so the notebook does not echo the Text object.
_ = ax.set_title("simulated data")
../_images/numpyro_2_0.png
import jax

# Switch JAX to 64-bit floats (it defaults to 32-bit). This has to run before
# any JAX arrays are created. The flag is set through the top-level
# ``jax.config`` object because the ``from jax.config import config`` import
# path was deprecated and later removed from JAX.
jax.config.update("jax_enable_x64", True)

from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

from tinygp import kernels, GaussianProcess


# Shared scale of every prior used by the model below.
prior_sigma = 5.0


def numpyro_model(t, yerr, y=None):
    """numpyro model of a GP with a two-component kernel.

    Sample sites: ``mean`` ~ Normal(0, prior_sigma); ``jitter``, ``sigma1``,
    ``rho1``, ``tau``, ``sigma2`` and ``rho2`` ~ HalfNormal(prior_sigma). The
    data enter through the ``gp`` site, a multivariate normal built from the
    GP's ``loc`` and ``scale_tril``.

    Args:
        t: Input coordinates handed to ``GaussianProcess``.
        yerr: Per-point uncertainties; ``yerr ** 2 + jitter`` is passed as the
            GP's ``diag``.
        y: Observed values for the ``gp`` site. With the default ``None`` the
            site is left unobserved and no prediction is recorded.

    When ``y`` is given, ``gp.predict(y, true_t)`` is stored as the
    deterministic site ``pred``. ``prior_sigma`` and ``true_t`` are read from
    module scope.
    """

    def half_normal(site):
        # Positive-valued parameter drawn from the shared HalfNormal prior.
        return numpyro.sample(site, dist.HalfNormal(prior_sigma))

    mu = numpyro.sample("mean", dist.Normal(0.0, prior_sigma))
    jit = half_normal("jitter")

    # First component: scaled product of ExpSquared(tau) and Cosine(rho1).
    amp1 = half_normal("sigma1")
    rho_cos = half_normal("rho1")
    tau_exp = half_normal("tau")
    first = amp1 ** 2 * kernels.ExpSquared(tau_exp)
    first = first * kernels.Cosine(rho_cos)

    # Second component: scaled Matern32(rho2).
    amp2 = half_normal("sigma2")
    rho_mat = half_normal("rho2")
    second = amp2 ** 2 * kernels.Matern32(rho_mat)

    gp = GaussianProcess(
        first + second,
        t,
        diag=yerr ** 2 + jit,
        mean=mu,
    )

    likelihood = dist.MultivariateNormal(loc=gp.loc, scale_tril=gp.scale_tril)
    numpyro.sample("gp", likelihood, obs=y)

    # Without observations there is nothing to condition the prediction on.
    if y is None:
        return
    numpyro.deterministic("pred", gp.predict(y, true_t))


# NUTS sampler for the model, using a dense mass matrix and a target
# acceptance probability of 0.9.
nuts_kernel = NUTS(
    numpyro_model,
    target_accept_prob=0.9,
    dense_mass=True,
)

# Two chains with 1000 warm-up and 1000 kept draws each, progress bar off.
# When fewer devices than chains are available, numpyro warns and draws the
# chains one after another.
run_settings = dict(
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
    progress_bar=False,
)
mcmc = MCMC(nuts_kernel, **run_settings)

# Fixed PRNG key for the sampler run.
rng_key = random.PRNGKey(34923)
/home/docs/checkouts/readthedocs.org/user_builds/tinygp/envs/latest/lib/python3.7/site-packages/numpyro/infer/mcmc.py:280: UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
  self.num_chains, local_device_count(), self.num_chains
%%time
# The cell magic above must stay on the first line of the cell.
# Run the sampler on the simulated data and collect the draws.
mcmc.run(rng_key, t, yerr, y=y)
samples = mcmc.get_samples()
# Wait for the "pred" values to be fully computed before the cell ends, so
# that the time reported by %%time covers the whole computation.
pred = samples["pred"].block_until_ready()  # Blocking to get timing right
CPU times: user 2min 2s, sys: 37.6 s, total: 2min 40s
Wall time: 1min 26s
# 5th, 50th and 95th percentiles of `pred`, taken along its first axis
# (presumably the axis that indexes the posterior draws).
q = np.percentile(pred, [5, 50, 95], axis=0)
lower, median, upper = q

# Shaded 5-95% band with the median on top, then the data for comparison.
ax = plt.gca()
ax.fill_between(true_t, lower, upper, color="C0", alpha=0.5)
ax.plot(true_t, median, color="C0", lw=2)

ax.errorbar(t, y, yerr=yerr, fmt=".k", capsize=0)
ax.set_xlabel("x [day]")
ax.set_ylabel("y [ppm]")
ax.set_xlim(0, 10)
ax.set_ylim(-2.5, 2.5)
# Assign the return value so the notebook does not echo the Text object.
_ = ax.set_title("posterior inference")
../_images/numpyro_5_0.png
import arviz as az

# Convert the finished run into an arviz object.
data = az.from_numpyro(mcmc)

# Summarise every posterior variable except "pred".
summary_vars = [
    name for name in data.posterior.data_vars if name != "pred"
]
az.summary(data, var_names=summary_vars)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
jitter 0.005 0.005 0.000 0.014 0.000 0.000 1306.0 929.0 1.0
mean 0.016 1.490 -3.067 2.768 0.055 0.041 832.0 744.0 1.0
rho1 2.172 0.402 1.798 2.662 0.026 0.020 978.0 334.0 1.0
rho2 7.804 3.262 2.017 13.464 0.092 0.065 1163.0 1056.0 1.0
sigma1 0.967 0.443 0.410 1.701 0.016 0.012 1157.0 731.0 1.0
sigma2 1.759 0.968 0.475 3.588 0.032 0.023 969.0 1049.0 1.0
tau 2.436 1.014 0.675 4.111 0.039 0.028 643.0 497.0 1.0