Tutorial: Modeling Frameworks

try:
    import tinygp
except ImportError:
    %pip install -q tinygp
    
try:
    import numpyro
except ImportError:
    %pip uninstall -y jax jaxlib
    %pip install -q numpyro jax jaxlib

try:
    import arviz
except ImportError:
    %pip install arviz
    
try:
    import flax
except ImportError:
    %pip install -q flax
    
try:
    import optax
except ImportError:
    %pip install -q optax

One of the key design decisions made in the tinygp API is that it shouldn’t place strong constraints on your modeling or inference choices. Most existing Python-based GP libraries require a large amount of buy-in from users, and we wanted to avoid that here. That being said, you will be required to buy into jax as your computational backend, but there is a rich ecosystem of jax-based modeling frameworks that should all be compatible with tinygp. In this tutorial, we demonstrate how you might combine tinygp with some popular jax-based modeling frameworks:

  1. Optimization with flax & optax, and

  2. Sampling with numpyro.

Similar examples should be possible with other libraries like TensorFlow Probability, PyMC (version > 4.0), mcx, or BlackJAX, to name a few.
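
Because the only hard requirement is jax, the marginal log likelihood that tinygp builds is an ordinary jax function that any of these frameworks can consume. As a minimal, self-contained sketch (using the v0.1 API, where GaussianProcess.condition returns the marginal log likelihood), you can pass that function through any jax transformation directly:

import jax
import jax.numpy as jnp
from tinygp import kernels, GaussianProcess

# Toy inputs and observations, just to illustrate the interface
x_demo = jnp.linspace(0.0, 1.0, 10)
y_demo = jnp.sin(2 * jnp.pi * x_demo)


def neg_log_likelihood(log_scale):
    # Build a GP with a single length scale parameter and return the
    # negative marginal log likelihood of the toy data
    kernel = kernels.ExpSquared(jnp.exp(log_scale))
    gp = GaussianProcess(kernel, x_demo, diag=1e-4)
    return -gp.condition(y_demo)


# jit, grad, vmap, and friends all work on this function
value, grad = jax.value_and_grad(neg_log_likelihood)(0.0)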

To begin with, let’s simulate a dataset that we can use for our examples:

import numpy as np
import matplotlib.pyplot as plt

random = np.random.default_rng(42)

t = np.sort(
    np.append(
        random.uniform(0, 3.8, 28),
        random.uniform(5.5, 10, 18),
    )
)
yerr = random.uniform(0.08, 0.22, len(t))
y = (
    0.2 * (t - 5)
    + np.sin(3 * t + 0.1 * (t - 5) ** 2)
    + yerr * random.normal(size=len(t))
)

true_t = np.linspace(0, 10, 100)
true_y = 0.2 * (true_t - 5) + np.sin(3 * true_t + 0.1 * (true_t - 5) ** 2)

plt.plot(true_t, true_y, "k", lw=1.5, alpha=0.3)
plt.errorbar(t, y, yerr=yerr, fmt=".k", capsize=0)
plt.xlabel("x [day]")
plt.ylabel("y [ppm]")
plt.xlim(0, 10)
plt.ylim(-2.5, 2.5)
_ = plt.title("simulated data")
../_images/modeling_2_0.png

Optimization with flax & optax

Using our simulated dataset from above, we may want to find the maximum (marginal) likelihood hyperparameters for a GP model. One popular modeling framework that we can use for this task is flax. A benefit of integrating with flax is that we can then easily combine our GP model with other machine learning models for all sorts of fun results (see, for example, Example: Deep kernel learning).

To set up our model, we define a custom linen.Module, and optimize its parameters as follows:

from tinygp import kernels, GaussianProcess

import jax
import jax.numpy as jnp

import flax.linen as nn
from flax.linen.initializers import zeros

import optax


class GPModule(nn.Module):
    @nn.compact
    def __call__(self, x, yerr, y, t):
        mean = self.param("mean", zeros, ())
        log_jitter = self.param("log_jitter", zeros, ())

        log_sigma1 = self.param("log_sigma1", zeros, ())
        log_rho1 = self.param("log_rho1", zeros, ())
        log_tau = self.param("log_tau", zeros, ())
        kernel1 = (
            jnp.exp(2 * log_sigma1)
            * kernels.ExpSquared(jnp.exp(log_tau))
            * kernels.Cosine(jnp.exp(log_rho1))
        )

        log_sigma2 = self.param("log_sigma2", zeros, ())
        log_rho2 = self.param("log_rho2", zeros, ())
        kernel2 = jnp.exp(2 * log_sigma2) * kernels.Matern32(jnp.exp(log_rho2))

        kernel = kernel1 + kernel2
        gp = GaussianProcess(
            kernel, x, diag=yerr ** 2 + jnp.exp(log_jitter), mean=mean
        )

        # The negative marginal log likelihood is our training loss, and we
        # also compute the conditional mean at the test points
        loss = -gp.condition(y)
        pred = gp.predict(y, t)

        return loss, pred


def loss(params):
    return model.apply(params, t, yerr, y, true_t)[0]


model = GPModule()
params = model.init(jax.random.PRNGKey(0), t, yerr, y, true_t)
tx = optax.sgd(learning_rate=3e-3)
opt_state = tx.init(params)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))

losses = []
for i in range(1001):
    loss_val, grads = loss_grad_fn(params)
    losses.append(loss_val)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

plt.plot(losses)
plt.ylabel("negative log likelihood")
_ = plt.xlabel("step number")
../_images/modeling_4_0.png
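
Before looking at the predictions, it can be useful to check the values that the optimizer settled on. This is just a sketch that reads the fitted hyperparameters back out of the flax parameter dictionary, undoing the log transforms defined in GPModule above:

fitted = params["params"]
print("mean   =", float(fitted["mean"]))
print("jitter =", float(jnp.exp(fitted["log_jitter"])))
print("sigma1 =", float(jnp.exp(fitted["log_sigma1"])))
print("rho1   =", float(jnp.exp(fitted["log_rho1"])))
print("tau    =", float(jnp.exp(fitted["log_tau"])))
print("sigma2 =", float(jnp.exp(fitted["log_sigma2"])))
print("rho2   =", float(jnp.exp(fitted["log_rho2"])))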

Our Module defined above also returns the conditional predictions, which we can compare to the true model:

pred = model.apply(params, t, yerr, y, true_t)[1]

plt.plot(true_t, true_y, "k", lw=1.5, alpha=0.3, label="truth")
plt.errorbar(t, y, yerr=yerr, fmt=".k", capsize=0)
plt.plot(true_t, pred, label="max likelihood model")
plt.xlabel("x [day]")
plt.ylabel("y [ppm]")
plt.xlim(0, 10)
plt.ylim(-2.5, 2.5)
plt.legend()
_ = plt.title("maximum likelihood")
../_images/modeling_6_0.png

Sampling with numpyro

Perhaps we’re not satisfied with just a point estimate of our hyperparameters and we want to instead compute posterior expectations. One tool for doing that is numpyro, which offers Markov chain Monte Carlo (MCMC) and variational inference methods. As a demonstration, here’s how we would set up the model from above and run MCMC in numpyro:

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

prior_sigma = 5.0


def numpyro_model(t, yerr, y=None):
    mean = numpyro.sample("mean", dist.Normal(0.0, prior_sigma))
    jitter = numpyro.sample("jitter", dist.HalfNormal(prior_sigma))

    sigma1 = numpyro.sample("sigma1", dist.HalfNormal(prior_sigma))
    rho1 = numpyro.sample("rho1", dist.HalfNormal(prior_sigma))
    tau = numpyro.sample("tau", dist.HalfNormal(prior_sigma))
    kernel1 = sigma1 ** 2 * kernels.ExpSquared(tau) * kernels.Cosine(rho1)

    sigma2 = numpyro.sample("sigma2", dist.HalfNormal(prior_sigma))
    rho2 = numpyro.sample("rho2", dist.HalfNormal(prior_sigma))
    kernel2 = sigma2 ** 2 * kernels.Matern32(rho2)

    kernel = kernel1 + kernel2
    gp = GaussianProcess(kernel, t, diag=yerr ** 2 + jitter, mean=mean)
    numpyro.sample("gp", gp.numpyro_dist(), obs=y)

    if y is not None:
        numpyro.deterministic("pred", gp.predict(y, true_t))


nuts_kernel = NUTS(numpyro_model, dense_mass=True, target_accept_prob=0.9)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
    progress_bar=False,
)
rng_key = jax.random.PRNGKey(34923)
(Since only one CPU device is available here, numpyro warns that the two chains cannot be run in parallel and will be drawn sequentially; calling numpyro.set_host_device_count(2) at the beginning of the program enables parallel chains.)
%%time
mcmc.run(rng_key, t, yerr, y=y)
samples = mcmc.get_samples()
pred = samples["pred"].block_until_ready()  # Blocking to get timing right
CPU times: user 1min 9s, sys: 20 s, total: 1min 29s
Wall time: 49.7 s
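
The pred site that we stored as a deterministic in the model gives the conditional mean evaluated at true_t for every posterior sample; with the settings above (2 chains of 1,000 samples each) its shape should be (2000, 100):

print(samples["pred"].shape)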

When running iterative methods like MCMC, it’s always a good idea to check some convergence diagnostics. For that task, let’s use ArviZ:

import arviz as az

data = az.from_numpyro(mcmc)
az.summary(
    data, var_names=[v for v in data.posterior.data_vars if v != "pred"]
)
          mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
jitter   0.003  0.003   0.000    0.009      0.000    0.000    1157.0     757.0   1.00
mean     0.020  1.517  -2.963    3.040      0.057    0.044     789.0     555.0   1.01
rho1     2.251  0.461   1.781    2.845      0.029    0.021     617.0     356.0   1.00
rho2     7.355  3.076   1.978   13.116      0.085    0.060    1158.0     794.0   1.00
sigma1   1.078  0.551   0.465    1.976      0.019    0.015    1420.0     828.0   1.00
sigma2   1.783  1.037   0.465    3.759      0.034    0.024     916.0    1122.0   1.00
tau      2.263  1.086   0.725    4.184      0.043    0.030     501.0     465.0   1.01
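
The r_hat values close to 1 and the reasonably large effective sample sizes suggest that the chains are well mixed. For a more qualitative check, it can also be useful to look at trace plots; here is a quick sketch using the same ArviZ data object:

_ = az.plot_trace(
    data, var_names=[v for v in data.posterior.data_vars if v != "pred"]
)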

Finally, we can plot our posterior inference of the conditional process, compared to the true model:

q = np.percentile(pred, [5, 50, 95], axis=0)
plt.fill_between(true_t, q[0], q[2], color="C0", alpha=0.5, label="inference")
plt.plot(true_t, q[1], color="C0", lw=2)
plt.plot(true_t, true_y, "k", lw=1.5, alpha=0.3, label="truth")

plt.errorbar(t, y, yerr=yerr, fmt=".k", capsize=0)
plt.xlabel("x [day]")
plt.ylabel("y [ppm]")
plt.xlim(0, 10)
plt.ylim(-2.5, 2.5)
plt.legend()
_ = plt.title("posterior inference")
../_images/modeling_13_0.png